Skip to content

Commit c3f49de

Browse files
authored
add sandbox.sh (Metta-AI#234)
# Replace SSH with SSM for AWS Batch Instance Access This PR replaces SSH-based access to AWS Batch instances with AWS Systems Manager (SSM) Session Manager, which provides more secure access without requiring SSH keys or public IP addresses. ## Changes: - Renamed `get_job_ip()` to `get_job_instance_id()` to reflect its updated purpose - Replaced `ssh_to_job()` with `ssm_to_job()` to use SSM instead of SSH - Added a comprehensive `setup_ssm.sh` script to help configure SSM access - Updated error messages to reflect SSM-specific issues - Added "sandbox" as a valid command option in `launch_task.py` - Added a simple `sandbox.sh` script for development environments - Added the session-manager-plugin to the Brewfile for macOS users These changes improve security by eliminating the need for SSH keys and public IP addresses while providing a more reliable connection method to AWS Batch instances. This was mostly vibe-coded but it's overriding previously low-quality scripts. So don't review too carefully.
1 parent f13fa00 commit c3f49de

File tree

6 files changed

+419
-47
lines changed

6 files changed

+419
-47
lines changed

devops/aws/batch/job.py

Lines changed: 51 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -745,11 +745,10 @@ def launch_job(job_queue=None):
745745
return False
746746

747747

748-
def get_job_ip(job_id_or_name):
749-
"""Get the public IP address of the instance running a job."""
748+
def get_job_instance_id(job_id_or_name):
749+
"""Get the instance ID of the instance running a job."""
750750
batch = get_boto3_client("batch")
751751
ecs = get_boto3_client("ecs")
752-
ec2 = get_boto3_client("ec2")
753752

754753
try:
755754
# First try to get the job by ID
@@ -801,7 +800,7 @@ def get_job_ip(job_id_or_name):
801800

802801
# Check if it's a multi-node job
803802
if "nodeProperties" in job:
804-
print(f"Error: Job '{job['jobId']}' is a multi-node job. SSH is not supported for multi-node jobs.")
803+
print(f"Error: Job '{job['jobId']}' is a multi-node job. SSM is not supported for multi-node jobs.")
805804
return None
806805

807806
# Get the task ARN and cluster
@@ -830,75 +829,84 @@ def get_job_ip(job_id_or_name):
830829
)
831830
ec2_instance_id = container_instance_desc["containerInstances"][0]["ec2InstanceId"]
832831

833-
# Get the public IP address
834-
instances = ec2.describe_instances(InstanceIds=[ec2_instance_id])
835-
if "PublicIpAddress" in instances["Reservations"][0]["Instances"][0]:
836-
public_ip = instances["Reservations"][0]["Instances"][0]["PublicIpAddress"]
837-
return public_ip
838-
else:
839-
print(f"No public IP address found for job '{job['jobId']}'")
840-
return None
832+
return ec2_instance_id
841833

842834
except Exception as e:
843-
print(f"Error retrieving job IP: {str(e)}")
835+
print(f"Error retrieving job instance ID: {str(e)}")
844836
return None
845837

846838

847-
def ssh_to_job(job_id_or_name, instance_only=False):
848-
"""Connect to the instance running a job via SSH.
839+
def ssm_to_job(job_id_or_name, instance_only=False):
840+
"""Connect to the instance running a job via SSM.
849841
850842
Args:
851843
job_id_or_name: The job ID or name to connect to
852844
instance_only: If True, connect directly to the instance without attempting to connect to the container
853845
"""
854846

855-
# Get the IP address of the job
856-
ip = get_job_ip(job_id_or_name)
857-
if not ip:
847+
# Get the instance ID of the job
848+
instance_id = get_job_instance_id(job_id_or_name)
849+
if not instance_id:
858850
return False
859851

860852
try:
861-
# Establish SSH connection and check if it's successful
862-
print(f"Checking SSH connection to {ip}...")
863-
ssh_check_cmd = f"ssh -o StrictHostKeyChecking=no -o ConnectTimeout=10 {ip} 'echo Connected'"
864-
ssh_check_output = subprocess.check_output(ssh_check_cmd, shell=True).decode().strip()
865-
if ssh_check_output != "Connected":
866-
raise subprocess.CalledProcessError(1, "SSH connection check failed")
853+
# Start SSM session
854+
print(f"Starting SSM session to instance {instance_id}...")
867855

868856
if instance_only:
869-
# Connect directly to the instance
870-
print(f"Connecting directly to the instance at {ip}...")
871-
ssh_cmd = f"ssh -o StrictHostKeyChecking=no -t {ip}"
872-
subprocess.run(ssh_cmd, shell=True)
857+
# Connect directly to the instance as root
858+
print(f"Connecting directly to the instance {instance_id} as root...")
859+
ssm_cmd = (
860+
f"aws ssm start-session --target {instance_id} "
861+
"--document-name AWS-StartInteractiveCommand "
862+
"--parameters 'command=sudo su -'"
863+
)
864+
subprocess.run(ssm_cmd, shell=True)
873865
else:
874866
# Retrieve container ID
875-
print(f"Finding container on {ip}...")
876-
container_cmd = f"ssh -o StrictHostKeyChecking=no -t {ip} \"docker ps | grep 'metta'\""
867+
print(f"Finding container on instance {instance_id}...")
868+
container_cmd = (
869+
f"aws ssm start-session --target {instance_id} "
870+
"--document-name AWS-StartInteractiveCommand "
871+
"--parameters 'command=sudo docker ps | grep metta'"
872+
)
877873
container_id_output = subprocess.check_output(container_cmd, shell=True).decode().strip()
878874

879-
if container_id_output:
880-
container_id = container_id_output.split()[0]
881-
print(f"Connecting to container {container_id} on {ip}...")
882-
exec_cmd = f'ssh -o StrictHostKeyChecking=no -t {ip} "docker exec -it {container_id} /bin/bash"'
875+
# Split output into lines and skip header
876+
container_lines = container_id_output.split("\n")
877+
if len(container_lines) > 1: # Skip header line
878+
container_id = container_lines[1].split()[0] # First column is container ID
879+
print(f"Connecting to container {container_id} on instance {instance_id}...")
880+
exec_cmd = (
881+
f"aws ssm start-session --target {instance_id} "
882+
"--document-name AWS-StartInteractiveCommand "
883+
f"--parameters 'command=sudo docker exec -it {container_id} /bin/bash'"
884+
)
883885
subprocess.run(exec_cmd, shell=True)
884886
else:
885-
print(f"No container running the 'mettaai/metta' image found on the instance {ip}.")
886-
print("Connecting to the instance directly...")
887-
ssh_cmd = f"ssh -o StrictHostKeyChecking=no -t {ip}"
888-
subprocess.run(ssh_cmd, shell=True)
887+
print(f"No container running the 'mettaai/metta' image found on the instance {instance_id}.")
888+
print("Connecting to the instance directly as root...")
889+
ssm_cmd = (
890+
f"aws ssm start-session --target {instance_id} "
891+
"--document-name AWS-StartInteractiveCommand "
892+
"--parameters 'command=sudo su -'"
893+
)
894+
subprocess.run(ssm_cmd, shell=True)
889895

890896
return True
891897
except subprocess.CalledProcessError as e:
892898
print(f"Error: {str(e)}")
893899
if "Connection timed out" in str(e):
894-
print(f"SSH connection to {ip} timed out. Please check the instance status and network connectivity.")
895-
elif "Connection refused" in str(e):
900+
print(f"SSM connection to {instance_id} timed out. Please check the instance status and IAM permissions.")
901+
elif "AccessDeniedException" in str(e):
896902
print(
897-
f"SSH connection to {ip} was refused. Please check if the instance is running and accepts SSH "
898-
"connections."
903+
f"SSM connection to {instance_id} was denied. Please check if "
904+
"the instance has the required IAM permissions and if the SSM "
905+
"agent is running."
899906
)
900907
else:
901908
print(
902-
f"An error occurred while connecting to {ip}. Please check the instance status and SSH configuration."
909+
f"An error occurred while connecting to {instance_id}. Please "
910+
"check the instance status and SSM configuration."
903911
)
904912
return False

devops/aws/batch/launch_task.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,9 @@ def main():
243243
parser = argparse.ArgumentParser(description="Launch an AWS Batch task with a wandb key.")
244244
parser.add_argument("--cluster", default="metta", help="The name of the ECS cluster.")
245245
parser.add_argument("--run", required=True, help="The run id.")
246-
parser.add_argument("--cmd", required=True, choices=["train", "sweep", "evolve"], help="The command to run.")
246+
parser.add_argument(
247+
"--cmd", required=True, choices=["train", "sweep", "evolve", "sandbox"], help="The command to run."
248+
)
247249

248250
parser.add_argument(
249251
"--git-branch",

0 commit comments

Comments
 (0)