import json
import logging
import os
import time
from configparser import ConfigParser
from datetime import datetime, timezone
from enum import Enum
from functools import partial
from logging.config import fileConfig

# A nosec comment is appended to the following line in order to disable the B404 check.
# In this file the input of the module subprocess is trusted.
from subprocess import CalledProcessError  # nosec B404
from typing import Dict, List

from botocore.config import Config
from common.schedulers.slurm_commands import (
    PartitionNodelistMapping,
    get_nodes_info,
    get_partitions_info,
    reset_nodes,
    set_nodes_power_down,
    update_all_partitions,
    update_partitions,
)
from common.time_utils import seconds
from common.utils import check_command_output, read_json, sleep_remaining_loop_time, time_is_up, wait_remaining_time
from retrying import retry
from slurm_plugin.capacity_block_manager import CapacityBlockManager
from slurm_plugin.cluster_event_publisher import ClusterEventPublisher
from slurm_plugin.common import TIMESTAMP_FORMAT, ScalingStrategy, log_exception, print_with_count
from slurm_plugin.console_logger import ConsoleLogger
from slurm_plugin.instance_manager import InstanceManager
from slurm_plugin.slurm_resources import (
    ComputeResourceFailureEvent,
    EC2InstanceHealthState,
    PartitionStatus,
    SlurmNode,
    SlurmPartition,
    StaticNode,
)
from slurm_plugin.task_executor import TaskExecutor

# Default period of the clustermgtd management loop, in seconds; overridable via the loop_time config option.
LOOP_TIME = 60
CONSOLE_OUTPUT_WAIT_TIME = 5 * 60
MAXIMUM_TASK_BACKLOG = 100

log = logging.getLogger(__name__)
compute_logger = log.getChild("console_output")
event_logger = log.getChild("events")
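
# compute_logger and event_logger are child loggers of the module logger; the logging configuration file
# referenced by ClustermgtdConfig.DEFAULTS["logging_config"] can therefore route "console_output" and
# "events" records to dedicated handlers without changing how the main clustermgtd log is emitted.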


class ComputeFleetStatus(Enum):
    """Represents the status of the cluster compute fleet."""

    STOPPED = "STOPPED"  # Fleet is stopped, partitions are inactive.
    RUNNING = "RUNNING"  # Fleet is running, partitions are active.
    STOPPING = "STOPPING"  # clustermgtd is handling the stop request.
    STARTING = "STARTING"  # clustermgtd is handling the start request.
    STOP_REQUESTED = "STOP_REQUESTED"  # A request to stop the fleet has been submitted.
    START_REQUESTED = "START_REQUESTED"  # A request to start the fleet has been submitted.
    # PROTECTED indicates that some partitions have consistent bootstrap failures. Affected partitions are inactive.
    PROTECTED = "PROTECTED"

    def __str__(self):
        return str(self.value)

    @staticmethod
    def is_start_requested(status):
        return status == ComputeFleetStatus.START_REQUESTED

    @staticmethod
    def is_stop_requested(status):
        return status == ComputeFleetStatus.STOP_REQUESTED

    @staticmethod
    def is_protected(status):
        return status == ComputeFleetStatus.PROTECTED
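
# Typical status flow: a user request moves the fleet through START_REQUESTED -> STARTING -> RUNNING or
# STOP_REQUESTED -> STOPPING -> STOPPED, while PROTECTED is entered by clustermgtd itself when partitions show
# consistent bootstrap failures.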


class ComputeFleetStatusManager:
    COMPUTE_FLEET_STATUS_ATTRIBUTE = "status"
    COMPUTE_FLEET_LAST_UPDATED_TIME_ATTRIBUTE = "lastStatusUpdatedTime"

    @staticmethod
    def get_status(fallback=None):
        try:
            compute_fleet_raw_data = check_command_output("get-compute-fleet-status.sh")
            log.debug("Retrieved compute fleet data: %s", compute_fleet_raw_data)
            return ComputeFleetStatus(
                json.loads(compute_fleet_raw_data).get(ComputeFleetStatusManager.COMPUTE_FLEET_STATUS_ATTRIBUTE)
            )
        except Exception as e:
            if isinstance(e, CalledProcessError):
                error = e.stdout.rstrip()
            else:
                error = e
            log.error("Failed when retrieving fleet status with error: %s, using fallback value %s", error, fallback)
            return fallback

    @staticmethod
    def update_status(status):
        try:
            check_command_output(f"update-compute-fleet-status.sh --status {status}")
        except CalledProcessError as e:
            log.error("Failed when updating fleet status to status %s, with error: %s", status, e.stdout.rstrip())
            raise
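
# get-compute-fleet-status.sh is expected to print a JSON document exposing at least the "status" and
# "lastStatusUpdatedTime" attributes, e.g. {"status": "RUNNING", "lastStatusUpdatedTime": "..."}; only the
# "status" attribute is consumed here.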


class ClustermgtdConfig:
    DEFAULTS = {
        "max_retry": 1,  # boto3 retries once by default (see _get_basic_config)
        "loop_time": LOOP_TIME,
        "logging_config": os.path.join(
            os.path.dirname(__file__), "logging", "parallelcluster_clustermgtd_logging.conf"
        ),
        "launch_max_batch_size": 500,
        "assign_node_max_batch_size": 500,
        "update_node_address": True,
        "run_instances_overrides": "/opt/slurm/etc/pcluster/run_instances_overrides.json",
        "create_fleet_overrides": "/opt/slurm/etc/pcluster/create_fleet_overrides.json",
        "fleet_config_file": "/etc/parallelcluster/slurm_plugin/fleet-config.json",
        "terminate_max_batch_size": 1000,
        # Timeout to wait for node initialization, should be the same as ResumeTimeout
        "node_replacement_timeout": 1800,
        "terminate_drain_nodes": True,
        "terminate_down_nodes": True,
        "orphaned_instance_timeout": 300,
        "ec2_instance_missing_max_count": 0,
        "disable_ec2_health_check": False,
        "disable_scheduled_event_health_check": False,
        "disable_all_cluster_management": False,
        "health_check_timeout": 180,
        "health_check_timeout_after_slurmdstarttime": 180,
        "disable_capacity_blocks_management": False,
        "use_private_hostname": False,
        "protected_failure_count": 10,
        "insufficient_capacity_timeout": 600,
        # Compute console logging configs
        "compute_console_logging_enabled": True,
        "compute_console_logging_max_sample_size": 1,
        "compute_console_wait_time": CONSOLE_OUTPUT_WAIT_TIME,
        "worker_pool_max_backlog": MAXIMUM_TASK_BACKLOG,
    }
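
    # Options that have no entry in DEFAULTS (e.g. proxy, hosted_zone, dns_domain, worker_pool_size) fall back
    # to None through DEFAULTS.get() when they are not present in the clustermgtd config file.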

    def __init__(self, config_file_path):
        self._get_config(config_file_path)

    def __repr__(self):
        attrs = ", ".join(
            ["{key}={value}".format(key=key, value=repr(value)) for key, value in self.__dict__.items()]
        )
        return "{class_name}({attrs})".format(class_name=self.__class__.__name__, attrs=attrs)

    def __eq__(self, other):
        if type(other) is type(self):
            return (
                self._config == other._config
                and self.fleet_config == other.fleet_config
                and self.run_instances_overrides == other.run_instances_overrides
                and self.create_fleet_overrides == other.create_fleet_overrides
            )
        return False

    def __ne__(self, other):
        return not self.__eq__(other)

    def _get_basic_config(self, config):
        """Get basic config options."""
        self.region = config.get("clustermgtd", "region")
        self.cluster_name = config.get("clustermgtd", "cluster_name")
        self.dynamodb_table = config.get("clustermgtd", "dynamodb_table")
        self.head_node_private_ip = config.get("clustermgtd", "head_node_private_ip")
        self.head_node_hostname = config.get("clustermgtd", "head_node_hostname")
        self.head_node_instance_id = config.get("clustermgtd", "instance_id", fallback="unknown")

        # Configure boto3 to retry 1 time by default
        self._boto3_retry = config.getint("clustermgtd", "boto3_retry", fallback=self.DEFAULTS.get("max_retry"))
        self._boto3_config = {"retries": {"max_attempts": self._boto3_retry, "mode": "standard"}}
        self.loop_time = config.getint("clustermgtd", "loop_time", fallback=self.DEFAULTS.get("loop_time"))
        self.disable_all_cluster_management = config.getboolean(
            "clustermgtd",
            "disable_all_cluster_management",
            fallback=self.DEFAULTS.get("disable_all_cluster_management"),
        )
        self.heartbeat_file_path = config.get("clustermgtd", "heartbeat_file_path")

        proxy = config.get("clustermgtd", "proxy", fallback=self.DEFAULTS.get("proxy"))
        if proxy:
            self._boto3_config["proxies"] = {"https": proxy}
        self.boto3_config = Config(**self._boto3_config)

        self.logging_config = config.get("clustermgtd", "logging_config", fallback=self.DEFAULTS.get("logging_config"))
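
    # The resulting boto3 Config uses the "standard" retry mode with max_attempts taken from boto3_retry, and an
    # HTTPS proxy, when configured, is applied to every client created from self.boto3_config.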

    def _get_launch_config(self, config):
        """Get config options related to launching instances."""
        self.launch_max_batch_size = config.getint(
            "clustermgtd", "launch_max_batch_size", fallback=self.DEFAULTS.get("launch_max_batch_size")
        )
        self.assign_node_max_batch_size = config.getint(
            "clustermgtd", "assign_node_max_batch_size", fallback=self.DEFAULTS.get("assign_node_max_batch_size")
        )
        self.update_node_address = config.getboolean(
            "clustermgtd", "update_node_address", fallback=self.DEFAULTS.get("update_node_address")
        )

        fleet_config_file = config.get(
            "clustermgtd", "fleet_config_file", fallback=self.DEFAULTS.get("fleet_config_file")
        )
        self.fleet_config = read_json(fleet_config_file)

        # run_instances_overrides_file and create_fleet_overrides_file contain a json with the following format:
        # {
        #     "compute_resource_name": {
        #         <arbitrary-json-with-boto3-api-params-to-override>
        #     }
        # }
        run_instances_overrides_file = config.get(
            "clustermgtd", "run_instances_overrides", fallback=self.DEFAULTS.get("run_instances_overrides")
        )
        self.run_instances_overrides = read_json(run_instances_overrides_file, default={})
        create_fleet_overrides_file = config.get(
            "clustermgtd", "create_fleet_overrides", fallback=self.DEFAULTS.get("create_fleet_overrides")
        )
        self.create_fleet_overrides = read_json(create_fleet_overrides_file, default={})

    def _get_health_check_config(self, config):
        """Get config options related to health checks."""
        self.disable_ec2_health_check = config.getboolean(
            "clustermgtd", "disable_ec2_health_check", fallback=self.DEFAULTS.get("disable_ec2_health_check")
        )
        self.disable_scheduled_event_health_check = config.getboolean(
            "clustermgtd",
            "disable_scheduled_event_health_check",
            fallback=self.DEFAULTS.get("disable_scheduled_event_health_check"),
        )
        self.health_check_timeout = config.getint(
            "clustermgtd", "health_check_timeout", fallback=self.DEFAULTS.get("health_check_timeout")
        )
        self.health_check_timeout_after_slurmdstarttime = config.getint(
            "clustermgtd",
            "health_check_timeout_after_slurmdstarttime",
            fallback=self.DEFAULTS.get("health_check_timeout_after_slurmdstarttime"),
        )
        self.disable_all_health_checks = config.getboolean(
            "clustermgtd",
            "disable_all_health_checks",
            fallback=(self.disable_ec2_health_check and self.disable_scheduled_event_health_check),
        )
        self.disable_capacity_blocks_management = config.getboolean(
            "clustermgtd",
            "disable_capacity_blocks_management",
            fallback=self.DEFAULTS.get("disable_capacity_blocks_management"),
        )

    def _get_terminate_config(self, config):
        """Get config options related to instance termination and node replacement."""
        self.terminate_max_batch_size = config.getint(
            "clustermgtd", "terminate_max_batch_size", fallback=self.DEFAULTS.get("terminate_max_batch_size")
        )
        self.node_replacement_timeout = config.getint(
            "clustermgtd", "node_replacement_timeout", fallback=self.DEFAULTS.get("node_replacement_timeout")
        )
        self.terminate_drain_nodes = config.getboolean(
            "clustermgtd", "terminate_drain_nodes", fallback=self.DEFAULTS.get("terminate_drain_nodes")
        )
        self.terminate_down_nodes = config.getboolean(
            "clustermgtd", "terminate_down_nodes", fallback=self.DEFAULTS.get("terminate_down_nodes")
        )
        self.orphaned_instance_timeout = config.getint(
            "clustermgtd", "orphaned_instance_timeout", fallback=self.DEFAULTS.get("orphaned_instance_timeout")
        )
        self.protected_failure_count = config.getint(
            "clustermgtd", "protected_failure_count", fallback=self.DEFAULTS.get("protected_failure_count")
        )
        self.insufficient_capacity_timeout = config.getfloat(
            "clustermgtd", "insufficient_capacity_timeout", fallback=self.DEFAULTS.get("insufficient_capacity_timeout")
        )
        self.ec2_instance_missing_max_count = config.getint(
            "clustermgtd",
            "ec2_instance_missing_max_count",
            fallback=self.DEFAULTS.get("ec2_instance_missing_max_count"),
        )
        self.disable_nodes_on_insufficient_capacity = self.insufficient_capacity_timeout > 0

    def _get_dns_config(self, config):
        """Get config option related to Route53 DNS domain."""
        self.hosted_zone = config.get("clustermgtd", "hosted_zone", fallback=self.DEFAULTS.get("hosted_zone"))
        self.dns_domain = config.get("clustermgtd", "dns_domain", fallback=self.DEFAULTS.get("dns_domain"))
        self.use_private_hostname = config.getboolean(
            "clustermgtd", "use_private_hostname", fallback=self.DEFAULTS.get("use_private_hostname")
        )

    def _get_compute_console_output_config(self, config):
        """Get config options related to logging console output from compute nodes."""
        self.compute_console_logging_enabled = config.getboolean(
            "clustermgtd",
            "compute_console_logging_enabled",
            fallback=self.DEFAULTS.get("compute_console_logging_enabled"),
        )
        self.compute_console_logging_max_sample_size = config.getint(
            "clustermgtd",
            "compute_console_logging_max_sample_size",
            fallback=self.DEFAULTS.get("compute_console_logging_max_sample_size"),
        )
        self.compute_console_wait_time = config.getint(
            "clustermgtd", "compute_console_wait_time", fallback=self.DEFAULTS.get("compute_console_wait_time")
        )

    def _get_worker_pool_config(self, config):
        self.worker_pool_size = config.getint(
            "clustermgtd", "worker_pool_size", fallback=self.DEFAULTS.get("worker_pool_size")
        )
        self.worker_pool_max_backlog = config.getint(
            "clustermgtd", "worker_pool_max_backlog", fallback=self.DEFAULTS.get("worker_pool_max_backlog")
        )

    @log_exception(log, "reading cluster manager configuration file", catch_exception=IOError, raise_on_error=True)
    def _get_config(self, config_file_path):
        """Get clustermgtd configuration."""
        log.info("Reading %s", config_file_path)
        self._config = ConfigParser()
        self._config.read_file(open(config_file_path, "r"))
        self._get_basic_config(self._config)
        self._get_health_check_config(self._config)
        self._get_launch_config(self._config)
        self._get_terminate_config(self._config)
        self._get_dns_config(self._config)
        self._get_compute_console_output_config(self._config)
        self._get_worker_pool_config(self._config)
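
# Illustrative usage (not part of this module): the clustermgtd daemon entry point typically builds this class
# from the clustermgtd config file, e.g.
# ClustermgtdConfig("/etc/parallelcluster/slurm_plugin/parallelcluster_clustermgtd.conf"),
# and re-applies it through ClusterManager.set_config at every loop iteration so config changes are picked up live.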
"""Class for all cluster management related actions."""
_config
:
ClustermgtdConfig
class
HealthCheckTypes
(
Enum
):
"""Enum for health check types."""
scheduled_event
=
"scheduled_events_check"
ec2_health
=
"ec2_health_check"
class
EC2InstancesInfoUnavailable
(
Exception
):
"""Exception raised when unable to retrieve cluster instance info from EC2."""
def
__init__
(
self
,
config
):
Initialize ClusterManager.
self.static_nodes_in_replacement is persistent across multiple iteration of manage_cluster
This state is required because we need to ignore static nodes that might have long bootstrap time
self
.
_insufficient_capacity_compute_resources
=
{}
self
.
_static_nodes_in_replacement
=
set
()
self
.
_partitions_protected_failure_count_map
=
{}
self
.
_nodes_without_backing_instance_count_map
=
{}
self
.
_compute_fleet_status
=
ComputeFleetStatus
.
RUNNING
self
.
_current_time
=
None
self
.
_compute_fleet_status_manager
=
None
self
.
_instance_manager
=
None
self
.
_task_executor
=
None
self
.
_console_logger
=
None
self
.
_event_publisher
=
None
self
.
_partition_nodelist_mapping_instance
=
None
self
.
_capacity_block_manager
=
None
self
.
set_config
(
config
)

    def set_config(self, config: ClustermgtdConfig):
        if self._config != config:
            log.info("Applying new clustermgtd config: %s", config)
            # If a new task executor is needed, the old one will be shutdown.
            # This should cause any pending tasks to be cancelled (Py 3.9) and
            # any executing tasks to raise an exception as long as they test
            # the executor for shutdown. If tracking of cancelled or failed
            # tasks is needed, then the Future object returned from task_executor
            # can be queried to determine how the task exited.
            # The shutdown on the task_executor is by default, non-blocking, so
            # it is possible for some tasks to continue executing even after the
            # shutdown request has returned.
            self._task_executor = self._initialize_executor(config)
            self._event_publisher = ClusterEventPublisher.create_with_default_publisher(
                event_logger, config.cluster_name, "HeadNode", "clustermgtd", config.head_node_instance_id
            )
            self._compute_fleet_status_manager = ComputeFleetStatusManager()
            self._instance_manager = self._initialize_instance_manager(config)
            self._console_logger = self._initialize_console_logger(config)
            self._capacity_block_manager = self._initialize_capacity_block_manager(config)
            self._config = config

    def shutdown(self):
        """Shut down the task executor, if any."""
        if self._task_executor:
            self._task_executor.shutdown()
            self._task_executor = None

    @staticmethod
    def _initialize_instance_manager(config):
        """Initialize instance manager class that will be used to launch/terminate/describe instances."""
        return InstanceManager(
            config.region,
            config.cluster_name,
            config.boto3_config,
            table_name=config.dynamodb_table,
            hosted_zone=config.hosted_zone,
            dns_domain=config.dns_domain,
            use_private_hostname=config.use_private_hostname,
            head_node_private_ip=config.head_node_private_ip,
            head_node_hostname=config.head_node_hostname,
            run_instances_overrides=config.run_instances_overrides,
            create_fleet_overrides=config.create_fleet_overrides,
            fleet_config=config.fleet_config,
        )

    def _initialize_executor(self, config):
        # Recreate the task executor only when the worker pool settings change.
        if (
            not self._task_executor
            or self._config.worker_pool_size != config.worker_pool_size
            or self._config.worker_pool_max_backlog != config.worker_pool_max_backlog
        ):
            if self._task_executor:
                self._task_executor.shutdown()
            self._task_executor = TaskExecutor(
                worker_pool_size=config.worker_pool_size,
                max_backlog=config.worker_pool_max_backlog,
            )
        return self._task_executor

    @staticmethod
    def _initialize_console_logger(config):
        return ConsoleLogger(
            region=config.region,
            enabled=config.compute_console_logging_enabled,
            console_output_consumer=lambda name, instance_id, output: compute_logger.info(
                "Console output for node %s (Instance Id %s):\r%s", name, instance_id, output
            ),
        )

    @staticmethod
    def _initialize_capacity_block_manager(config):
        return CapacityBlockManager(
            region=config.region, fleet_config=config.fleet_config, boto3_config=config.boto3_config
        )

    def _update_compute_fleet_status(self, status):
        log.info("Updating compute fleet status from %s to %s", self._compute_fleet_status, status)
        self._compute_fleet_status_manager.update_status(status)
        self._compute_fleet_status = status

    def _handle_successfully_launched_nodes(self, partitions_name_map):
        """
        Handle nodes that failed during bootstrap but have been launched successfully during this iteration.

        This includes resetting the partition bootstrap failure count for those node types and updating
        bootstrap_failure_nodes_types: once a previously failing node type launches successfully, the partition
        count for that node type is reset and the node type is removed from bootstrap_failure_nodes_types.
        If node types in the partition fail again later, counting starts over.
        """
        # Find node types that failed during bootstrap but are now online.
        partitions_protected_failure_count_map = self._partitions_protected_failure_count_map.copy()
        for partition, failures_per_compute_resource in partitions_protected_failure_count_map.items():
            partition_online_compute_resources = partitions_name_map[partition].get_online_node_by_type(
                self._config.terminate_drain_nodes,
                self._config.terminate_down_nodes,
                self._config.ec2_instance_missing_max_count,
                self._nodes_without_backing_instance_count_map,
            )
            for compute_resource in failures_per_compute_resource.keys():
                if compute_resource in partition_online_compute_resources:
                    self._reset_partition_failure_count(partition)

    def manage_cluster(self):
        """Manage cluster by syncing scheduler states with EC2 states and performing node maintenance actions."""
        log.info("Managing cluster...")
        self._current_time = datetime.now(tz=timezone.utc)
        self._compute_fleet_status = self._compute_fleet_status_manager.get_status(
            fallback=self._compute_fleet_status
        )
        log.info("Current compute fleet status: %s", self._compute_fleet_status)

        if not self._config.disable_all_cluster_management:
            if self._compute_fleet_status in {
                ComputeFleetStatus.RUNNING,
                ComputeFleetStatus.PROTECTED,
            }:
                # Get partition_nodelist_mapping between PC-managed Slurm partitions and PC-managed Slurm nodelists
                # Initialize PartitionNodelistMapping singleton
                self._partition_nodelist_mapping_instance = PartitionNodelistMapping.instance()

                # Get node states for nodes in inactive and active partitions
                try:
                    log.info("Retrieving nodes info from the scheduler")
                    nodes = self._get_node_info_with_retry()
                    log.debug("Nodes: %s", nodes)
                    partitions_name_map, compute_resource_nodes_map = self._parse_scheduler_nodes_data(nodes)
                except Exception as e:
                    log.error(
                        "Unable to get partition/node info from slurm, no other action can be performed. "
                        "Sleeping... Exception: %s",
                        e,
                    )
                    return

                # Get all non-terminating instances in EC2
                try:
                    cluster_instances = self._get_ec2_instances()
                except ClusterManager.EC2InstancesInfoUnavailable:
                    log.error("Unable to get instances info from EC2, no other action can be performed. Sleeping...")
                    return
                log.debug("Current cluster instances in EC2: %s", cluster_instances)

                partitions = list(partitions_name_map.values())
                self._update_slurm_nodes_with_ec2_info(nodes, cluster_instances)
                self._event_publisher.publish_compute_node_events(nodes, cluster_instances)
                # Handle inactive partition and terminate backing instances
                self._clean_up_inactive_partition(partitions)
                # Perform health check actions
                if not self._config.disable_all_health_checks:
                    self._perform_health_check_actions(partitions)
                self._maintain_nodes(partitions_name_map, compute_resource_nodes_map)
                # Clean up orphaned instances
                self._terminate_orphaned_instances(cluster_instances)
            elif self._compute_fleet_status in {
                ComputeFleetStatus.STOPPED,
            }:
                # Since Slurm partition status might have been manually modified, when STOPPED we want to keep checking
                # partitions and EC2 instances to take into account changes that can be manually
                # applied by the user by re-activating Slurm partitions.
                # When partitions are INACTIVE, always try to reset nodeaddr/nodehostname to avoid issues.
                self._maintain_nodes_down()

        # Write clustermgtd heartbeat to file
        self._write_timestamp_to_file()
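
    # manage_cluster is expected to run once per loop_time iteration of the daemon main loop; the heartbeat written
    # by _write_timestamp_to_file below is the file compute nodes read to decide whether the head node is still alive.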

    def _write_timestamp_to_file(self):
        """Write timestamp into shared file so compute nodes can determine if head node is online."""
        # Make clustermgtd heartbeat readable to all users
        with open(
            os.open(self._config.heartbeat_file_path, os.O_WRONLY | os.O_CREAT, 0o644), "w"
        ) as timestamp_file:
            # Note: heartbeat must be written with datetime.strftime to convert localized datetime into str
            # datetime.strptime will not work with str(datetime)
            timestamp_file.write(datetime.now(tz=timezone.utc).strftime(TIMESTAMP_FORMAT))

    @staticmethod
    @retry(stop_max_attempt_number=2, wait_fixed=1000)
    def _get_node_info_with_retry(nodes=""):
        return get_nodes_info(nodes)

    @staticmethod
    @retry(stop_max_attempt_number=2, wait_fixed=1000)
    def _get_partitions_info_with_retry() -> Dict[str, SlurmPartition]:
        return {part.name: part for part in get_partitions_info()}
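
    # The retrying decorators above make each scheduler query attempt at most twice (stop_max_attempt_number=2)
    # with a fixed 1000 ms pause between attempts before the exception is propagated to the caller.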

    def _clean_up_inactive_partition(self, partitions):
        """Terminate all other instances associated with nodes in INACTIVE partition directly through EC2."""
        inactive_instance_ids, inactive_nodes = ClusterManager._get_inactive_instances_and_nodes(partitions)
        if inactive_nodes:
            try:
                log.info("Cleaning up INACTIVE partitions.")
                if inactive_instance_ids:
                    log.info(
                        "Clean up instances associated with nodes in INACTIVE partitions: %s",
                        print_with_count(inactive_instance_ids),
                    )
                    self._instance_manager.delete_instances(
                        inactive_instance_ids, terminate_batch_size=self._config.terminate_max_batch_size
                    )
                self._reset_nodes_in_inactive_partitions(list(inactive_nodes))
            except Exception as e:
                log.error(
                    "Failed to clean up INACTIVE nodes %s with exception %s", print_with_count(inactive_nodes), e
                )

    @staticmethod
    def _reset_nodes_in_inactive_partitions(inactive_nodes):
        # Try to reset nodeaddr if possible to avoid potential problems
        nodes_to_reset = set()
        for node in inactive_nodes:
            if node.needs_reset_when_inactive():
                nodes_to_reset.add(node.name)
        if nodes_to_reset:
            # Setting to down and not power_down because, while inactive, power_down doesn't seem to be applied
            log.info(
                "Resetting nodeaddr/nodehostname and setting to down the following nodes: %s",
                print_with_count(nodes_to_reset),
            )
            try:
                reset_nodes(nodes_to_reset, state="down", reason="inactive partition")
            except Exception as e:
                log.error(
                    "Encountered exception when resetting nodeaddr for INACTIVE nodes %s: %s",
                    print_with_count(nodes_to_reset),
                    e,
                )

    def _get_ec2_instances(self):
        """
        Get EC2 instances by calling the describe_instances API.

        The call is made by filtering on tags and includes non-terminating instances only.
        Instances returned will not contain instances previously terminated in _clean_up_inactive_partition.
        """
        # After reading Slurm nodes wait for 5 seconds to let instances appear in EC2 describe_instances call
        time.sleep(5)
        log.info("Retrieving list of EC2 instances associated with the cluster")
        try:
            return self._instance_manager.get_cluster_instances(include_head_node=False, alive_states_only=True)
        except Exception as e:
            log.error("Failed when getting instance info from EC2 with exception %s", e)
            raise ClusterManager.EC2InstancesInfoUnavailable

    @log_exception(log, "maintaining slurm nodes down", catch_exception=Exception, raise_on_error=False)
    def _maintain_nodes_down(self):
        update_all_partitions(PartitionStatus.INACTIVE, reset_node_addrs_hostname=True)
        self._instance_manager.terminate_all_compute_nodes(self._config.terminate_max_batch_size)

    @log_exception(log, "performing health check action", catch_exception=Exception, raise_on_error=False)
    def _perform_health_check_actions(self, partitions: List[SlurmPartition]):
        """Run health check actions."""
        log.info("Performing instance health check actions")
        instance_id_to_active_node_map = ClusterManager.get_instance_id_to_active_node_map(partitions)
        if not instance_id_to_active_node_map:
            return
        # Get health states for instances that might be considered unhealthy
        unhealthy_instances_status = self._instance_manager.get_unhealthy_cluster_instance_status(
            list(instance_id_to_active_node_map.keys())
        )
        log.debug("Cluster instances that might be considered unhealthy: %s", unhealthy_instances_status)
        if unhealthy_instances_status:
            # Perform EC2 health check actions
            if not self._config.disable_ec2_health_check:
                self._handle_health_check(
                    unhealthy_instances_status,
                    instance_id_to_active_node_map,
                    health_check_type=ClusterManager.HealthCheckTypes.ec2_health,
                )
            # Perform scheduled event actions
            if not self._config.disable_scheduled_event_health_check:
                self._handle_health_check(
                    unhealthy_instances_status,
                    instance_id_to_active_node_map,
                    health_check_type=ClusterManager.HealthCheckTypes.scheduled_event,
                )

    def _get_nodes_failing_health_check(
        self, unhealthy_instances_status, instance_id_to_active_node_map, health_check_type
    ):
        """Get nodes failing the given health check."""
        log.info("Performing actions for health check type: %s", health_check_type)
        nodes_failing_health_check = []
        for instance_status in unhealthy_instances_status:
            unhealthy_node = instance_id_to_active_node_map.get(instance_status.id)
            if unhealthy_node and self._is_instance_unhealthy(instance_status, health_check_type):
                nodes_failing_health_check.append(unhealthy_node)
                unhealthy_node.is_failing_health_check = True
                log.warning(
                    "Node %s(%s) is associated with instance %s that is failing %s. EC2 health state: %s",
                    unhealthy_node.name,
                    unhealthy_node.nodeaddr,
                    instance_status.id,
                    health_check_type,
                    [
                        instance_status.state,
                        instance_status.instance_status,
                        instance_status.system_status,
                        instance_status.scheduled_events,
                    ],
                )
        return nodes_failing_health_check

    def _is_instance_unhealthy(self, instance_status: EC2InstanceHealthState, health_check_type):
        """Check if instance status is unhealthy based on the corresponding health check function."""
        is_instance_status_unhealthy = False
        if health_check_type == ClusterManager.HealthCheckTypes.scheduled_event:
            is_instance_status_unhealthy = instance_status.fail_scheduled_events_check()
        elif health_check_type == ClusterManager.HealthCheckTypes.ec2_health:
            is_instance_status_unhealthy = instance_status.fail_ec2_health_check(
                self._current_time, self._config.health_check_timeout
            )
        return is_instance_status_unhealthy
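
    # ec2_health evaluates the EC2 instance/system status checks against health_check_timeout, so transient
    # failures are tolerated for a while; scheduled_event flags instances with pending EC2 scheduled events.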

    def _update_static_nodes_in_replacement(self, slurm_nodes: List[SlurmNode]):
        """Remove from self.static_nodes_in_replacement nodes that are up or that are in maintenance."""
        nodename_to_slurm_nodes_map = {node.name: node for node in slurm_nodes}
        nodes_still_in_replacement = set()
        for nodename in self._static_nodes_in_replacement:
            node = nodename_to_slurm_nodes_map.get(nodename)
            # Consider nodename still in replacement if node is not up and not in maintenance
            if node and not node.is_up() and not node.is_in_maintenance():
                nodes_still_in_replacement.add(nodename)
        # override self._static_nodes_in_replacement with updated list
        self._static_nodes_in_replacement = nodes_still_in_replacement
        for node in slurm_nodes:
            node.is_static_nodes_in_replacement = node.name in self._static_nodes_in_replacement
            node.is_being_replaced = self._is_node_being_replaced(node)
            node._is_replacement_timeout = self._is_node_replacement_timeout(node)

    def _find_unhealthy_slurm_nodes(self, slurm_nodes):
        """
        Find unhealthy static slurm nodes and dynamic slurm nodes.

        Check and return slurm nodes with unhealthy and healthy scheduler state, grouping unhealthy nodes
        by node type (static/dynamic).
        """
        unhealthy_static_nodes = []
        unhealthy_dynamic_nodes = []
        ice_compute_resources_and_nodes_map = {}
        all_unhealthy_nodes = []
        reserved_nodenames = []

        # Remove the nodes part of inactive Capacity Blocks from the list of unhealthy nodes.
        # Nodes from active Capacity Blocks will be instead managed as unhealthy instances.
        if not self._config.disable_capacity_blocks_management:
            reserved_nodenames = self._capacity_block_manager.get_reserved_nodenames(slurm_nodes)
            if reserved_nodenames:
                log.info(
                    "The nodes associated with inactive Capacity Blocks and not considered as unhealthy nodes are: %s",
                    ",".join(reserved_nodenames),
                )
            else:
                log.debug("No nodes found associated with inactive Capacity Blocks.")

        for node in slurm_nodes:
            if not node.is_healthy(
                consider_drain_as_unhealthy=self._config.terminate_drain_nodes,
                consider_down_as_unhealthy=self._config.terminate_down_nodes,
                ec2_instance_missing_max_count=self._config.ec2_instance_missing_max_count,
                nodes_without_backing_instance_count_map=self._nodes_without_backing_instance_count_map,
                log_warn_if_unhealthy=node.name not in reserved_nodenames,
            ):
                if not self._config.disable_capacity_blocks_management and node.name in reserved_nodenames:
                    # do not consider as unhealthy the nodes reserved for capacity blocks
                    continue
                all_unhealthy_nodes.append(node)

                if isinstance(node, StaticNode):
                    unhealthy_static_nodes.append(node)
                elif self._config.disable_nodes_on_insufficient_capacity and node.is_ice():
                    ice_compute_resources_and_nodes_map.setdefault(node.queue_name, {}).setdefault(
                        node.compute_resource_name, []
                    ).append(node)
                else:
                    unhealthy_dynamic_nodes.append(node)

        self._event_publisher.publish_unhealthy_node_events(
            self._config.ec2_instance_missing_max_count,
            self._nodes_without_backing_instance_count_map,
            unhealthy_dynamic_nodes,
            ice_compute_resources_and_nodes_map,
        )
        return (
            unhealthy_dynamic_nodes,
            unhealthy_static_nodes,
            ice_compute_resources_and_nodes_map,
        )

    def _increase_partitions_protected_failure_count(self, bootstrap_failure_nodes):
        """Keep count of bootstrap failures."""
        for node in bootstrap_failure_nodes:
            compute_resource = node.compute_resource_name
            queue_name = node.queue_name
            if queue_name in self._partitions_protected_failure_count_map:
                self._partitions_protected_failure_count_map[queue_name][compute_resource] = (
                    self._partitions_protected_failure_count_map[queue_name].get(compute_resource, 0) + 1
                )
            else:
                self._partitions_protected_failure_count_map[queue_name] = {}
                self._partitions_protected_failure_count_map[queue_name][compute_resource] = 1
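
    # _partitions_protected_failure_count_map has the shape {queue_name: {compute_resource_name: failure_count}},
    # e.g. {"queue1": {"compute-a": 3}}; these counts are compared against protected_failure_count when deciding
    # whether partitions must be placed into protected mode.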

    @log_exception(log, "maintaining unhealthy dynamic nodes", raise_on_error=False)
    def _handle_unhealthy_dynamic_nodes(self, unhealthy_dynamic_nodes):
        """
        Maintain any unhealthy dynamic node.

        Terminate instances backing dynamic nodes.
        Setting node to down will let slurm requeue jobs allocated to node.
        Setting node to power_down will terminate backing instance and reset dynamic node for future use.
        """
        instances_to_terminate = [node.instance.id for node in unhealthy_dynamic_nodes if node.instance]
        if instances_to_terminate:
            log.info("Terminating instances that are backing unhealthy dynamic nodes")
            self._instance_manager.delete_instances(
                instances_to_terminate, terminate_batch_size=self._config.terminate_max_batch_size
            )
        log.info("Setting unhealthy dynamic nodes to down and power_down.")
        set_nodes_power_down(
            [node.name for node in unhealthy_dynamic_nodes], reason="Scheduler health check failed"
        )

    @log_exception(log, "maintaining powering down nodes", raise_on_error=False)
    def _handle_powering_down_nodes(self, slurm_nodes):
        """
        Handle nodes that are powering down and not already being replaced.

        Terminate instances backing the powering down node if any.
        Reset the nodeaddr for the powering down node. Node state is not changed.
        """
        powering_down_nodes = [
            node for node in slurm_nodes if node.is_powering_down_with_nodeaddr() and not node.is_being_replaced
        ]
        if powering_down_nodes:
            log.info("Resetting powering down nodes: %s", print_with_count(powering_down_nodes))
            reset_nodes(nodes=[node.name for node in powering_down_nodes])
            instances_to_terminate = [node.instance.id for node in powering_down_nodes if node.instance]
            if instances_to_terminate:
                log.info("Terminating instances that are backing powering down nodes")
                self._instance_manager.delete_instances(
                    instances_to_terminate, terminate_batch_size=self._config.terminate_max_batch_size
                )

    @log_exception(log, "maintaining unhealthy static nodes", raise_on_error=False)
    def _handle_unhealthy_static_nodes(self, unhealthy_static_nodes):
        """
        Maintain any unhealthy static node.

        Set node to down, terminate backing instance, and launch new instance for static node.
        """
        # Report console output from the unhealthy nodes before their instances are terminated, if enabled.
        if self._config.compute_console_logging_enabled:
            try:
                wait_function = partial(
                    wait_remaining_time,
                    wait_start_time=datetime.now(tz=timezone.utc),
                    total_wait_time=self._config.compute_console_wait_time,
                    wait_function=self._task_executor.wait_unless_shutdown,
                )
                self._console_logger.report_console_output_from_nodes(
                    compute_instances=self._instance_manager.get_compute_node_instances(
                        unhealthy_static_nodes, self._config.compute_console_logging_max_sample_size
                    ),
                    task_controller=self._task_executor,
                    task_wait_function=wait_function,
                )
            except Exception as e:
                log.error("Encountered exception when retrieving console output from unhealthy static nodes: %s", e)

        node_list = [node.name for node in unhealthy_static_nodes]
        # Set nodes into down state so jobs can be requeued immediately
        try:
            log.info("Setting unhealthy static nodes to DOWN")
            reset_nodes(node_list, state="down", reason="Static node maintenance: unhealthy node is being replaced")
        except Exception as e:
            log.error("Encountered exception when setting unhealthy static nodes into down state: %s", e)
        instances_to_terminate = [node.instance.id for node in unhealthy_static_nodes if node.instance]
        if instances_to_terminate:
            log.info("Terminating instances backing unhealthy static nodes")
            self._instance_manager.delete_instances(
                instances_to_terminate, terminate_batch_size=self._config.terminate_max_batch_size
            )
        log.info("Launching new instances for unhealthy static nodes")
        self._instance_manager.add_instances(
            node_list=node_list,
            launch_batch_size=self._config.launch_max_batch_size,
            assign_node_batch_size=self._config.assign_node_max_batch_size,
            update_node_address=self._config.update_node_address,
            scaling_strategy=ScalingStrategy.BEST_EFFORT,
        )
        # Add launched nodes to list of nodes being replaced, excluding any nodes that failed to launch
        failed_nodes = set().union(*self._instance_manager.failed_nodes.values())
        launched_nodes = set(node_list) - failed_nodes
        self._static_nodes_in_replacement |= launched_nodes
        log.info(
            "After node maintenance, following nodes are currently in replacement: %s",
            print_with_count(self._static_nodes_in_replacement),
        )
        self._event_publisher.publish_unhealthy_static_node_events(
            self._static_nodes_in_replacement,
            self._instance_manager.failed_nodes,
        )

    @log_exception(log, "maintaining slurm nodes", catch_exception=Exception, raise_on_error=False)
    def _maintain_nodes(self, partitions_name_map, compute_resource_nodes_map):
        """
        Call functions to maintain unhealthy nodes.

        This function needs to handle the case that 2 slurm nodes have the same IP/nodeaddr.
        A list of slurm nodes is passed in, and a slurm node map with IP/nodeaddr as key should be avoided.
        """
        log.info("Performing node maintenance actions")
        # Retrieve nodes from Slurm partitions in ACTIVE state
        active_nodes = self._find_active_nodes(partitions_name_map)
        # Update self.static_nodes_in_replacement by removing from the set any node that is up or in maintenance
        self._update_static_nodes_in_replacement(active_nodes)
        log.info(
            "Following nodes are currently in replacement: %s", print_with_count(self._static_nodes_in_replacement)
        )
        # terminate powering down instances
        self._handle_powering_down_nodes(active_nodes)
        # retrieve and manage unhealthy nodes
        (
            unhealthy_dynamic_nodes,
            unhealthy_static_nodes,
            ice_compute_resources_and_nodes_map,
        ) = self._find_unhealthy_slurm_nodes(active_nodes)
        if unhealthy_dynamic_nodes:
            log.info("Found the following unhealthy dynamic nodes: %s", print_with_count(unhealthy_dynamic_nodes))
            self._handle_unhealthy_dynamic_nodes(unhealthy_dynamic_nodes)
        if unhealthy_static_nodes:
            log.info("Found the following unhealthy static nodes: %s", print_with_count(unhealthy_static_nodes))
            self._handle_unhealthy_static_nodes(unhealthy_static_nodes)
        # evaluate partitions to put in protected mode and ICE nodes to terminate
        if self._is_protected_mode_enabled():
            self._handle_protected_mode_process(active_nodes, partitions_name_map)
        if self._config.disable_nodes_on_insufficient_capacity:
            self._handle_ice_nodes(ice_compute_resources_and_nodes_map, compute_resource_nodes_map)
        self._handle_failed_health_check_nodes_in_replacement(active_nodes)

    @log_exception(log, "terminating orphaned instances", catch_exception=Exception, raise_on_error=False)
    def _terminate_orphaned_instances(self, cluster_instances):
        """Terminate instances not associated with any node and running longer than orphaned_instance_timeout."""
        log.info("Checking for orphaned instance")
        instances_to_terminate = []
        for instance in cluster_instances:
            if not instance.slurm_node and time_is_up(
                instance.launch_time, self._current_time, self._config.orphaned_instance_timeout
            ):
                instances_to_terminate.append(instance.id)
        if instances_to_terminate:
            log.info("Terminating orphaned instances")
            self._instance_manager.delete_instances(
                instances_to_terminate, terminate_batch_size=self._config.terminate_max_batch_size
            )

    def _enter_protected_mode(self, partitions_to_disable):
        """Enter protected mode if there is no active running job in the queue."""
        # Place partitions into inactive
        log.info("Placing bootstrap failure partitions to INACTIVE: %s", partitions_to_disable)
        update_partitions(partitions_to_disable, PartitionStatus.INACTIVE)
        # Change compute fleet status to protected
        if not ComputeFleetStatus.is_protected(self._compute_fleet_status):
            log.warning(
                "Setting cluster into protected mode due to failures detected in node provisioning. "
                "Please investigate the issue and then use 'pcluster update-compute-fleet --status START_REQUESTED' "
                "command to re-enable the fleet."
            )
            self._update_compute_fleet_status(ComputeFleetStatus.PROTECTED)

    def _reset_partition_failure_count(self, partition):
        """Reset bootstrap failure count for partition which has bootstrap failure nodes successfully launched."""
        log.info("Find successfully launched node in partition %s, reset partition protected failure count", partition)
        self._partitions_protected_failure_count_map.pop(partition, None)

    def _handle_bootstrap_failure_nodes(self, active_nodes):
        """
        Find bootstrap failure nodes and increase partition failure count.

        There are two kinds of bootstrap failure nodes:
        Nodes fail in bootstrap during health check,