# Source code for graflag.ssh

"""SSH operations for GraFlag."""

import logging
import os
import shlex
import subprocess
from pathlib import Path, PurePosixPath
from typing import List, Optional

logger = logging.getLogger(__name__)


class SSHManager:
    """Handle SSH operations to remote manager.

    Remote commands run as ``root@<manager_ip>`` through the local ``ssh``
    client; file transfers run ``rsync`` with the same SSH options.

    NOTE(review): every connection passes ``-o StrictHostKeyChecking=no``,
    i.e. strict host-key checking is disabled. Kept for backward
    compatibility.
    """

    def __init__(self, manager_ip: str, ssh_port: str = "22", ssh_key: Optional[str] = None):
        """Initialize SSH manager.

        Args:
            manager_ip: Host name or IP address of the remote manager.
            ssh_port: SSH port. Converted with ``str()`` when commands are
                built, so an ``int`` is accepted as well.
            ssh_key: Optional path to the SSH key. ``~`` is expanded and a
                trailing ``.pub`` is stripped before use. When falsy, no
                ``-i`` option is passed to ssh.
        """
        self.manager_ip = manager_ip
        self.ssh_port = ssh_port
        self.ssh_key = ssh_key

    # ------------------------------------------------------------------ helpers

    def _ssh_options(self) -> List[str]:
        """Return the ssh option arguments shared by every connection."""
        opts = ["-o", "StrictHostKeyChecking=no"]
        if self.ssh_key:
            key_path = Path(self.ssh_key).expanduser()
            # Callers sometimes hand over the public key; ``-i`` needs the
            # private key that sits next to it without the suffix.
            if str(key_path).endswith(".pub"):
                key_path = key_path.with_suffix("")
            opts.extend(["-i", str(key_path)])
        opts.extend(["-p", str(self.ssh_port)])
        return opts

    def _rsync_shell(self) -> str:
        """Return the remote-shell command string passed to ``rsync -e``.

        rsync splits this string itself; each argument is shell-quoted so a
        key path containing spaces survives. Plain arguments are unchanged.
        """
        return " ".join(shlex.quote(arg) for arg in ["ssh", *self._ssh_options()])

    @staticmethod
    def _quote_remote(path: str) -> str:
        """Quote ``path`` for the remote POSIX shell.

        A leading ``~`` / ``~/`` is left unquoted so the remote shell still
        performs tilde expansion on it.
        """
        if path == "~":
            return path
        if path.startswith("~/"):
            return "~/" + shlex.quote(path[2:])
        return shlex.quote(path)

    def _remote_path(self, remote_shared_dir: str, path: str) -> str:
        """Join ``remote_shared_dir`` and ``path`` and quote the result for the remote shell."""
        return self._quote_remote(f"{remote_shared_dir}/{path}")

    def _rsync_base(self) -> List[str]:
        """Return the rsync argv prefix (flags plus ``-e`` remote shell) used for every copy."""
        return ["rsync", "-avz", "--progress", "--force", "-e", self._rsync_shell()]

    def _run_rsync(self, rsync_cmd: List[str], item_count: int, dest_label) -> None:
        """Run ``rsync_cmd``; log success or raise ``RuntimeError`` with rsync's stderr."""
        logger.debug(f"Executing rsync command: {' '.join(rsync_cmd)}")
        result = subprocess.run(rsync_cmd, capture_output=True, text=True)
        if result.returncode != 0:
            raise RuntimeError(f"Failed to copy files with rsync: {result.stderr}")
        logger.info(f"[OK] Successfully copied {item_count} item(s) to {dest_label}")

    # --------------------------------------------------------------- public API

    def execute(self, command: str, capture_output: bool = True) -> subprocess.CompletedProcess:
        """Execute command on manager via SSH.

        ``command`` is handed to ssh as a single argument after the
        destination, so it is interpreted by the remote shell (pipes,
        redirections and ``||`` work). No local shell is involved, so single
        quotes inside ``command`` no longer break the invocation.

        Args:
            command: Shell command line to run remotely.
            capture_output: Forwarded to ``subprocess.run``.

        Returns:
            The ``CompletedProcess`` of the local ``ssh`` invocation; output is
            decoded as text.
        """
        ssh_cmd = ["ssh", *self._ssh_options(), f"root@{self.manager_ip}", command]
        logger.debug(f"Executing SSH command: {' '.join(shlex.quote(arg) for arg in ssh_cmd)}")
        return subprocess.run(ssh_cmd, capture_output=capture_output, text=True)

    def path_exists(self, remote_shared_dir: str, path: str) -> bool:
        """Return True if ``remote_shared_dir/path`` exists on the manager (``test -e``)."""
        result = self.execute(f"test -e {self._remote_path(remote_shared_dir, path)}")
        return result.returncode == 0

    def read_file(self, remote_shared_dir: str, path: str) -> str:
        """Return the content of ``remote_shared_dir/path``, or ``""`` if ``cat`` fails."""
        result = self.execute(f"cat {self._remote_path(remote_shared_dir, path)}")
        if result.returncode == 0:
            return result.stdout
        return ""

    def mkdir(self, remote_shared_dir: str, path: str) -> bool:
        """Create ``remote_shared_dir/path`` (with parents); return True on success."""
        result = self.execute(f"mkdir -p {self._remote_path(remote_shared_dir, path)}")
        return result.returncode == 0

    def list_dir(self, remote_shared_dir: str, path: str) -> List[str]:
        """List entry names of ``remote_shared_dir/path`` on the manager.

        Returns an empty list when the directory is missing or empty (``ls``
        errors are discarded remotely) or when the ssh call itself fails.
        """
        remote = self._remote_path(remote_shared_dir, path)
        result = self.execute(f"ls -1 {remote} 2>/dev/null || true")
        output = result.stdout.strip() if result.returncode == 0 else ""
        if not output:
            return []
        return [item.strip() for item in output.split("\n") if item.strip()]

    def copy_files(self, source_paths, dest_path: str, recursive: bool = False, from_remote: bool = False) -> str:
        """
        Copy files/directories bidirectionally via rsync.

        Args:
            source_paths: Source path(s) - a single ``str``/``os.PathLike`` or
                any iterable of them
            dest_path: Destination path
            recursive: Accepted for API compatibility; rsync's ``-a`` flag
                already copies directories recursively, so it has no effect
            from_remote: If True, copy from remote to local; if False (default), copy from local to remote

        Returns:
            Destination path (the ``~``-expanded local path when ``from_remote``)

        Raises:
            FileNotFoundError: A local source path does not exist (upload only).
            RuntimeError: rsync exited with a non-zero status.
        """
        # A single path (including pathlib objects) becomes a one-item list;
        # any other iterable is materialized so it can be counted and reused.
        if isinstance(source_paths, (str, os.PathLike)):
            source_paths = [source_paths]
        else:
            source_paths = list(source_paths)

        if from_remote:
            return self._copy_from_remote(source_paths, dest_path, recursive)
        return self._copy_to_remote(source_paths, dest_path, recursive)

    def _copy_to_remote(self, local_paths, remote_dest: str, recursive: bool = False) -> str:
        """Copy files/directories from local to remote via rsync; return ``remote_dest``."""
        local_paths = list(local_paths)
        # Validate every local source before touching the remote side.
        local_path_objs = []
        for local_path in local_paths:
            local_path_obj = Path(local_path).expanduser()
            if not local_path_obj.exists():
                raise FileNotFoundError(f"Local path does not exist: {local_path}")
            local_path_objs.append(local_path_obj)

        # The remote side is POSIX regardless of the local OS, so use
        # PurePosixPath (plain Path would emit backslashes on Windows).
        parent_dir = str(PurePosixPath(remote_dest).parent)
        mkdir_result = self.execute(f"mkdir -p {self._quote_remote(parent_dir)}")
        if mkdir_result.returncode != 0:
            # Best effort: rsync below reports the definitive failure.
            logger.warning(f"Could not create remote directory {parent_dir}: {mkdir_result.stderr}")

        logger.info(f"[INFO] Copying {len(local_paths)} item(s) to {self.manager_ip}:{remote_dest}")

        rsync_cmd = self._rsync_base()
        rsync_cmd.extend(str(obj) for obj in local_path_objs)
        rsync_cmd.append(f"root@{self.manager_ip}:{remote_dest}")
        self._run_rsync(rsync_cmd, len(local_paths), remote_dest)
        return remote_dest

    def _copy_from_remote(self, remote_paths, local_dest: str, recursive: bool = False) -> str:
        """Copy files/directories from remote to local via rsync; return the expanded local dest."""
        remote_paths = list(remote_paths)
        local_dest_obj = Path(local_dest).expanduser()
        local_dest_obj.parent.mkdir(parents=True, exist_ok=True)

        logger.info(f"[INFO] Copying {len(remote_paths)} item(s) from {self.manager_ip} to {local_dest}")

        rsync_cmd = self._rsync_base()
        rsync_cmd.extend(f"root@{self.manager_ip}:{remote_path}" for remote_path in remote_paths)
        rsync_cmd.append(str(local_dest_obj))
        self._run_rsync(rsync_cmd, len(remote_paths), local_dest)
        return str(local_dest_obj)