Paramiko's SSHClient

I wrote a subclass of Paramiko's SSHClient to add functionality that helps me work with Ansible AWX 9.2.0 (the dockerless version).

The subclass works with the CentOS server that hosts my Ansible AWX. The purpose is to use the SSHClient class to download and upload playbooks, check whether a project folder exists and create it if it does not, and remove a directory or file on the remote server.

When uploading a file to the remote server, a sha256 digest of the file is calculated first; the script then uploads both the file and its digest to the remote server.

The SSHClient subclass has a built-in logger to record the session to a file. Logging is disabled by default, but the user can instantiate the class with the logger enabled.

Deep dive into the subclass code

This section describes the methods of the SSHClient subclass. This is my first time using Paramiko; setting up an SSH session with Linux devices is quite easy. I have not tried Netmiko, but since Netmiko also uses Paramiko under the hood, I believe it can do the same things.

Modules imported

from paramiko import SSHClient, AutoAddPolicy
from paramiko.ssh_exception import SSHException
from paramiko.util import log_to_file
from typing import Dict, Union, List
from pathlib import Path
import os
from types import MappingProxyType
import sys

SSHClient and AutoAddPolicy are both Paramiko classes. SSHClient is the class that does the SSH work; with it you do not need to worry about setting up the transport, because the class does it for you. My subclass inherits from SSHClient. Also, once an SSHClient instance is created you can easily get an SFTPClient by using its open_sftp() method.

AutoAddPolicy is a policy class that automatically adds unknown hosts to the known hosts. Reading the Paramiko code, the default policy is RejectPolicy(); in my subclass I changed self._policy to use AutoAddPolicy() instead.
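
To make this concrete, here is a minimal sketch of plain SSHClient usage without my subclass; the host address and credentials are placeholders:

from paramiko import SSHClient, AutoAddPolicy

client = SSHClient()
client.set_missing_host_key_policy(AutoAddPolicy())  # accept unknown hosts instead of rejecting them
client.connect("192.0.2.10", username="user", password="secret")  # placeholder credentials
sftp = client.open_sftp()  # SFTPClient comes from the same connection
sftp.close()
client.close()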

SSHException, imported from paramiko.ssh_exception, is the parent of all SSH-related exceptions, including AuthenticationException.

log_to_file, imported from paramiko.util, is a function that turns on the logger; it is built on the logging module. Reading the code, you will see log_to_file requires two arguments: the filename, and a log level which is DEBUG by default.
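
For example, a one-liner (the filename here is just an example) that logs everything at DEBUG and above:

from paramiko.util import log_to_file

# level 10 is DEBUG, which is also the default if level is not given.
log_to_file("paramiko-session.log", level=10)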

Dict, Union and List from typing are for type hinting. I am learning how to properly document code, and with type hints PyCharm alerts me when I pass in the wrong parameters or return an unintended data type.

Path from pathlib is used to get the home directory of the current OS. It can also be used to resolve directories: for example, Path(__file__).parent.absolute() immediately gives the absolute path of the directory containing your file, and .parent can be chained to keep moving up the directory tree.
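
A small sketch of both uses:

from pathlib import Path

print(Path.home())                       # e.g. /home/username or C:\Users\username
print(Path(__file__).parent.absolute())  # absolute path of the directory containing this file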

I am using the join functionality from os, os.path.join("path1", "path2"), which joins "path1" and "path2" in the directory format of the OS the script runs on. In the past I used from os.path import join, but I found that this is not a good habit: another module may expose a function with the same name, or I may want to write my own function called join to make my code more explicit, and by using the full path of the library I can better understand what join is actually doing and which library it comes from.
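
For example:

import os

print(os.path.join("path1", "path2"))  # path1/path2 on Linux, path1\path2 on Windows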

I am using MappingProxyType from types to create a dictionary-like object whose entries cannot be reassigned. I wanted an immutable dictionary and came across this type: once a MappingProxyType object is created, no new assignment is allowed. I need an immutable dictionary for referencing to keep my code human readable; a later section shows why I use MappingProxyType.
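
A small sketch of the read-only behaviour:

from types import MappingProxyType

levels = MappingProxyType({"DEBUG": 10, "INFO": 20})
print(levels["INFO"])    # 20
try:
    levels["INFO"] = 99  # any assignment raises TypeError
except TypeError as err:
    print(err)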

I am using import sys to access sys.exit() and sys.stdout.write().

Finally, from helper.progress_bar import progress_bar pulls in my own function that displays a progress bar while uploading and downloading files.

Constants

# Consolidate Errors
CONN_EXCEPTION = SSHException, TypeError, PermissionError

# Home directory for OS, works on Linux and Windows, not sure about others.
HOME_PATH = str(Path.home())

# https://docs.python.org/2/library/logging.html#levels
LOG_LEVEL = MappingProxyType(
    {
        "NOTSET": 0,
        "DEBUG": 10,
        "INFO": 20,
        "WARNING": 30,
        "ERROR": 40,
        "CRITICAL": 50
    }
)

CONN_EXCEPTION = SSHException, TypeError, PermissionError is a tuple of exceptions that removes the need to write multiple except blocks, because my handling is the same in every case; the only difference is the error message, which Python sorts out for me.
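
A minimal sketch of how one except clause covers the whole tuple:

try:
    raise PermissionError("permission denied")  # any exception in the tuple is caught the same way
except CONN_EXCEPTION as CE:
    print(f"Message: {str(CE)}")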

HOME_PATH = str(Path.home()) is the home directory; on Windows it looks like C:\Users\username, on Linux it is /home/username.

I am using MappingProxyType to hold a read-only dictionary that maps each log level name to its integer value.

# https://docs.python.org/2/library/logging.html#levels
LOG_LEVEL = MappingProxyType(
    {
        "NOTSET": 0,
        "DEBUG": 10,
        "INFO": 20,
        "WARNING": 30,
        "ERROR": 40,
        "CRITICAL": 50
    }
)

To access a value you simply do LOG_LEVEL["INFO"], for example to make log_to_file record anything that is informational or above.

Function that gets the hash of a file

This function sits outside the SSHClient subclass because it can run on its own, not only for the SSHClient but for other purposes. It is used by the download and upload methods of the subclass, abstracted away from the user.

def get_file_hash(base_path: str = None, filename: str = None):
    """
    https://nitratine.net/blog/post/how-to-hash-files-in-python/
    The concept, read the target file block by block, each block size is 65535 bytes.
    On each block calculate the hash until the EOF.
    Then write the entire digest back to a file.
    I have tested this with an 8KB Python script and a 600MB CentOS ISO and both results are good; on CentOS
    I used sha256sum to calculate the digests and they match the hash files uploaded.
    :param filename:
        name of the file
    :param base_path:
        base directory of the file, I use this because I need the base_path to write the hash to file.
    :return:
    """
    file_path = os.path.join(base_path, filename)

    # normally loaded at first, but i prefer the library to be loaded if in use.
    import hashlib

    BLOCK_SIZE = 65535
    digest = hashlib.sha256()
    with open(file_path, "rb") as file:
        # read the first 65535 bytes from the file.
        file_blocks = file.read(BLOCK_SIZE)
        while len(file_blocks) > 0:
            digest.update(file_blocks)
            file_blocks = file.read(BLOCK_SIZE)
    with open(os.path.join(base_path, f"{filename}.sha256"), "w") as write_hash:
        write_hash.write(digest.hexdigest())
    return f"{filename}.sha256"

I have imported hashlib inside the function because I want the library to be loaded only when the function is called.

You can refer to the blog linked in the code comment to find out how the whole block works.
In summary, the file is read in blocks, and each block's hash is folded into the digest until the entire file has been read. The digest is held in a hashlib hash object; to turn it into a string, use the hexdigest() method, digest.hexdigest(), and that string is written to a file.
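
As a quick sanity check, here is a minimal sketch (not part of the subclass) of how the written .sha256 file could be verified locally with the same block-wise reading; the helper name is my own:

import hashlib
import os


def verify_file_hash(base_path: str, filename: str) -> bool:
    # recompute the digest block by block, exactly like get_file_hash does
    digest = hashlib.sha256()
    with open(os.path.join(base_path, filename), "rb") as file:
        for block in iter(lambda: file.read(65535), b""):
            digest.update(block)
    # compare against the digest that get_file_hash wrote next to the file
    with open(os.path.join(base_path, f"{filename}.sha256")) as hash_file:
        return digest.hexdigest() == hash_file.read().strip()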

SSHClient subclass initialization

LinuxSSH is the name of my SSHClient subclass. I have seen a lot of code on GitHub that subclasses only to give a class another name, like this:

class SubClass(OriginalClass):
    pass

I am not sure of the purpose; if I need the features of a class I normally use it under its original name.
I wrote a subclass because I want to further abstract the way things are done, so that my main code is neater and more descriptive.

My subclass collects the username, password, hostname (default localhost), port number (default 22), log_file (the name of the log file) and level (the log level).

During instantiation an attempt to connect to the hostname is made; if log_file and level are specified, logging is enabled. AutoAddPolicy is used as the default instead of RejectPolicy; note that self._policy is an existing attribute of SSHClient.

A lot of examples use the set_missing_host_key_policy method; it does the same thing as my assignment to self._policy in the code below.

class LinuxSSH(SSHClient):
    """
    This is a sub class of SSHClient, this is to add on some methods while still having the functionalities of
    SSHClient.
    """

    def __init__(self, username: str = None,
                 password: str = None,
                 hostname: str = "127.0.0.1",
                 port: int = 22,
                 log_file: str = None,
                 level: int = None):
        super().__init__()
        if log_file is not None and level is not None:
            log_to_file(log_file, level=level)
        self.username = username
        self.password = password
        self.hostname = hostname
        self.port = port
        self._policy = AutoAddPolicy()  # modify the original default RejectPolicy()
        # Once an instance is created, a ssh session is attempted.
        try:
            self.connect(self.hostname, port=self.port, username=self.username, password=self.password)
        except CONN_EXCEPTION as CE:
            sys.stdout.write(f"Message: {str(CE)}")
            sys.exit(1)

Get the project directories

Ansible AWX job creation requires you to attach a project, and a manual project needs a playbook directory on the server, so this method helps me get the available playbook directories in a list. I deliberately wrap the response in a dictionary because I find a dictionary more controllable and easier to reference.

Reading the code, I found that the exec_command method returns a tuple (stdin, stdout, stderr). The useful ones here are stdin, which lets me put my sudo password into the input stream, and stdout, which captures the output of the command. To read stdout, use the read() method and then decode("utf-8") to convert the byte stream into a string.

The password written with stdin.write(self.password + "\n") is echoed back in stdout.read().decode("utf-8"), which is why I skip the first three lines with stdout_results.splitlines()[3:]: index 0 is the sudo password echo, index 1 is the sudo prompt for the password, and index 2 is the total 0 line; none of these are what I need.

The directory name is found at the end of each line of stdout_results.splitlines()[3:], so to extract it I do this:

for d in stdout_results.splitlines()[3:]:  # Not interested in password, sudo prompt and total 0
    pbdirs.append(d.split()[-1])  # Only directory name.

But the result will also include the . and .. entries, so to return only the actual directories I skip them with pbdirs[2:]. This slicing technique is essential for manipulating iterables (strings, tuples, lists).
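
Here is a minimal sketch of that parsing with made-up ls -lah lines; the first three lines of the real output (password echo, sudo prompt and the total line) are assumed to have been stripped already, and the directory names are placeholders:

sample_lines = [
    "drwxr-xr-x.  4 awx awx  54 Apr  1 10:00 .",
    "drwxr-xr-x. 20 awx awx 278 Apr  1 09:00 ..",
    "drwxr-xr-x.  2 awx awx  60 Apr  1 10:00 demo_dir",
    "drwxr-xr-x.  2 awx awx  60 Apr  1 10:05 network_playbooks",
]
pbdirs = [line.split()[-1] for line in sample_lines]
print(pbdirs[2:])  # ['demo_dir', 'network_playbooks'], the . and .. entries are skipped

The full get_project_dirs method is below: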

    def get_project_dirs(self, dirname: str = "/var/lib/awx/projects") -> Union[Dict[str, str], Dict[str, List]]:
        """
        To get a list of directories under the Ansible AWX base project directory.
        :param dirname:
            Project base directory. Ansible searches the yaml file from base directory.
        :return:
            dictionary of response.
        """
        try:
            # exec_command throws up a tuple(stdin, stdout, stderr)
            stdin, stdout, _ = self.exec_command(f"sudo ls -lah {dirname}", get_pty=True)
            stdin.write(self.password + "\n")
            stdin.flush()
            stdout_results = stdout.read().decode("utf-8")
            pbdirs = list()
            for d in stdout_results.splitlines()[3:]:  # Not interested in password, sudo prompt and total 0
                pbdirs.append(d.split()[-1])  # Only directory name.
            return {
                "status": "success",
                "playbook_dirs": pbdirs[2:] if len(pbdirs) > 2 else []  # not interested in . and ..
            }
        except CONN_EXCEPTION as CE:
            return {
                "status": "failed",
                "message": str(CE)
            }

Make the project directory

This method creates the project directory if it does not already exist; it uses get_project_dirs for verification before creating the directory.

response = self.get_project_dirs(dirname=base_path)
if response["status"] == "success":
    # a list of directory names where the playbook is stored.
    pbdirs = response["playbook_dirs"]

Because sudo is used to create the directory, a change of owner is required afterwards so that the awx user can use the directory.

commands = [f"sudo mkdir {base_path}/{dirname}",
           f"sudo chown -R awx:awx {base_path}/{dirname}"]

The entire method code is below:

    def create_project_dir(self, base_path: str = "/var/lib/awx/projects", dirname: str = None):
        """
        Create project directory.
        :param base_path:
            Project base path
        :param dirname:
            New project directory, this directory is attached to the manual project of Ansible AWX
        :return:
        """
        response = self.get_project_dirs(dirname=base_path)
        if response["status"] == "success":
            # a list of directory names where the playbook is stored.
            pbdirs = response["playbook_dirs"]
        else:
            return response
        if dirname not in pbdirs or pbdirs == list():
            # create the directory if not exists.
            commands = [f"sudo mkdir {base_path}/{dirname}",
                        f"sudo chown -R awx:awx {base_path}/{dirname}"]
            try:
                for command in commands:
                    stdin, stdout, _ = self.exec_command(command, get_pty=True)
                    # sudo password
                    stdin.write(self.password + "\n")
                    stdin.flush()
                    stdout.read().decode("utf-8")  # block until the command completes
            except CONN_EXCEPTION as CE:
                return {
                    "status": "failed",
                    "message": str(CE)
                }
        else:
            return {
                "status": "failed",
                "message": f"{dirname} exists."
            }

Remove project directory

This method uses get_project_dirs to check whether the directory to be deleted exists; if it does not, a failure response is returned.

response = self.get_project_dirs(dirname=base_path)
if response["status"] == "success":
    # list of directory names under project base directory.
    pbdirs = response["playbook_dirs"]
else:
    return response

I am hard-coding the Linux path format here; I cannot use os.path.join in this situation because I run the script from a Windows system, which would join the path in the Windows format.

command = f"sudo rm -rf {base_path}/{dirname}"

The method code is below:

    def remove_project_dir(self, base_path: str = "/var/lib/awx/projects", dirname: str = None):
        """
        This is for project directory removal, project directory is required to be attached to Ansible AWX
        project.
        :param base_path:
            Project base path.
        :param dirname:
            directory name under the base path.
        :return:
        """
        if dirname is None:
            return {
                "status": "failed",
                "message": "You have forgotten to provide the project dir name you want to remove."
            }
        response = self.get_project_dirs(dirname=base_path)
        if response["status"] == "success":
            # list of directory names under project base directory.
            pbdirs = response["playbook_dirs"]
        else:
            return response
        if dirname in pbdirs:
            # if the requested directory for removal exists, prepare the rm command.
            command = f"sudo rm -rf {base_path}/{dirname}"
            try:
                stdin, stdout, stderr = self.exec_command(command, get_pty=True)
                # sudo password
                stdin.write(self.password + "\n")
                stdin.flush()
            except CONN_EXCEPTION as CE:
                return {
                    "status": "failed",
                    "message": str(CE)
                }
        else:
            return {
                "status": "failed",
                "message": f"{dirname} does not exists."
            }

Download file from remote server

This method downloads a file from the remote server. The code uses self.open_sftp() to create an SFTPClient instance; open_sftp is a method of the parent class, SSHClient.

The code below uses a with context to open and close the SFTP session, so the user does not need to take care of opening and closing the session.

        with self.open_sftp() as sftp:
            """
            The sftp open and close session is handled here, so user does not need to close the sftp session.
            The purpose is to have a more straightforward way to download/upload files to target server.
            """
            # B is byte, b is bit. Miniters will have a better looking bar.
            # ascii=True will give # as the bar.
            callback, pbar = progress_bar(unit="B", unit_scale=True, miniters=1)
            try:
                sftp.get(src_abs_path, local_path, callback=callback)
            except CONN_EXCEPTION as CE:
                return {
                    "status": "failed",
                    "message": str(CE)
                }

The full download method is as follows:

    def download(self, src_abs_path: str = None,
                 dst_path: str = HOME_PATH,
                 dst_filename: str = None):
        """
        This method does sftp download, SSHClient can easily create a SFTPClient object by calling
        self.open_sftp(), once a SFTPClient instance is created we can use the get method to download.
        :param src_abs_path:
            Absolute path of the file you wish to download from the remote CentOS server.
        :param dst_path:
            Local path of your computer where this script is executed.
        :param dst_filename:
            Filename to be created for the downloaded object; this is optional, you can instead put the full path
            with the file name in dst_path.
        :return:
        """
        if dst_filename is not None:
            local_path = os.path.join(dst_path, dst_filename)
        else:
            local_path = dst_path
        with self.open_sftp() as sftp:
            """
            The sftp open and close session is handled here, so user does not need to close the sftp session.
            The purpose is to have a more straightforward way to download/upload files to target server.
            """
            # B is byte, b is bit. Miniters will have a better looking bar.
            # ascii=True will give # as the bar.
            callback, pbar = progress_bar(unit="B", unit_scale=True, miniters=1)
            try:
                sftp.get(src_abs_path, local_path, callback=callback)
            except CONN_EXCEPTION as CE:
                return {
                    "status": "failed",
                    "message": str(CE)
                }

Upload file from local system to remote server

This method uses the get_file_hash function to write out the hash of the file to be uploaded, then uploads both the file and the hash file to the remote server.
Similar to the download method, the upload method takes care of opening and closing the SFTP session using a with context.

    def upload(self, src_path: str = HOME_PATH,
               src_filename: str = None,
               dst_abs_path: str = None):
        """
        This method uploads the file from your computer to remote server.
        :param src_filename:
            the name of the file to upload; this is also used with get_file_hash.
        :param src_path:
            the base path of where the src_filename can be found.
        :param dst_abs_path:
            The absolute path the file will be uploaded. os.path.join works properly in source, if the remote is
            a different OS the path will be wrong.
        :return:
        """
        if src_filename is not None:
            local_path = os.path.join(src_path, src_filename)
        else:
            local_path = src_path
        remote_path = dst_abs_path
        digest_filename = get_file_hash(base_path=src_path, filename=src_filename)
        digest_abs_path = os.path.join(src_path, digest_filename)
        with self.open_sftp() as sftp:
            """
            The sftp open and close session is handled here, so user does not need to close the sftp session.
            The purpose is to have a more straightforward way to download/upload files to target server.
            """
            # b is bit, B is byte. If ascii=True, # will be used for progress bar.
            callback, pbar = progress_bar(unit="B", unit_scale=True, miniters=1)
            try:
                sftp.put(local_path, remote_path, callback=callback)
                # The session to remote_path is still on, hence only target filename is required.
                sftp.put(digest_abs_path, digest_filename, callback=callback)
            except CONN_EXCEPTION as CE:
                return {
                    "status": "failed",
                    "message": str(CE)
                }

Progress bar

The SFTPClient get and put methods accept a callback, so the file transfer progress can be reported back. To make it look good as a progress bar, tqdm is used.

See the links in my code's comments to understand how to integrate tqdm with Paramiko.

from tqdm import tqdm
"""
References to use tqdm with paramiko:
https://github.com/tqdm/tqdm/issues/311
https://raw.githubusercontent.com/tqdm/tqdm/master/examples/tqdm_wget.py

The get and put methods of the paramiko SFTP client support a callback. The callback signature is
func(int, int): the first int is the number of bytes transferred so far, the second int is the total number of bytes to be transferred.
"""


def progress_bar(*args, **kwargs):
    pbar = tqdm(*args, **kwargs)
    last = [0]  # bytes reported at the previous callback

    def progress_wrapper(transferred, to_be_transferred):
        pbar.total = int(to_be_transferred)
        pbar.update(int(transferred - last[0]))  # advance by the bytes transferred since the last callback
        last[0] = transferred  # remember the running total
    return progress_wrapper, pbar

The function returns a reference to the wrapper that updates the progress bar, together with the pbar object the wrapper uses to draw it.
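
A minimal sketch of how the callback behaves when driven by hand, the same way sftp.get and sftp.put would call it during a transfer; the byte counts are made up:

callback, pbar = progress_bar(unit="B", unit_scale=True, miniters=1)
for transferred in (16384, 32768, 65536):
    callback(transferred, 65536)  # (bytes transferred so far, total bytes)
pbar.close()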

Usage example

Below is the testing code.

from helper.linux import LinuxSSH, LOG_LEVEL
from getpass import getpass
import sys
from datetime import datetime

# Timestamp
time_now = datetime.now()
timestamp = time_now.strftime('%b-%d-%Y_%H%M')

print("*"*10 + "Demonstration" + "*"*10)
hostname = input("Remote server address: ")
username = input(f"Username of {hostname}: ")
password = getpass(f"Password of username {username}: ")

# This way if there is "" the value will be 0 and int() will not throw exception.
# if there is a number the prefix 0 will be ignored.
port = int("0" + input("SSH port? Press  to accept default port 22: "))
if port == 0:
    port = 22
log_enable = input("Enable logging? (Y or N): ")
if log_enable.lower() == "y":
    log_filename = input("Log filename: ")
    log_level = input("Log level? [DEBUG, INFO, WARNING, CRITICAL, ERROR]: ")
elif log_enable.lower() == "n":
    log_filename = None
    log_level = None
else:
    print("Unrecognized response, either Y or N.")
    sys.exit(1)

with LinuxSSH(username=username,
              password=password,
              hostname=hostname,
              port=port,
              log_file=f"{log_filename}-{timestamp}.log",
              level=LOG_LEVEL[log_level.upper()]) as linux:
    base_path = input(f"Set base path of {hostname}: ")
    dirname = input(f"Project directory name of {base_path}: ")
    print("*"*10 + "Creating directory" + "*"*10)
    linux.create_project_dir(base_path=base_path, dirname=dirname)
    dl_abs_path = input(f"Absolute path of the file you wish to download from {hostname} to your home directory: ")
    local_filename = input("The filename you wish to be in your home directory: ")
    print("*" * 10 + f"Downloading {dirname}" + "*" * 10)
    linux.download(src_abs_path=dl_abs_path, dst_filename=local_filename)
    ul_filename = input(f"Filename to upload to {hostname}: ")
    remote_abs_path = input(f"Remote absolute path including filename of {hostname}: ")
    print("*" * 10 + f"Uploading {dirname}" + "*" * 10)
    linux.upload(src_filename=ul_filename, dst_abs_path=remote_abs_path)

This is how it looks:
[Screenshot: demonstration run]

demo_dir is created as the subdirectory, and the two files uploaded from my Windows machine are centos.iso and centos.iso.sha256; the Python file is the one I downloaded to my home directory.
Below is my C:\Users\cyruslab directory on my Windows VM.
[Screenshot: C:\Users\cyruslab directory listing]

Below is the sha256 verification; both digests are the same, which means the file was transferred in its entirety.
[Screenshot: sha256sum verification]

The entire code of the SSHClient subclass

from paramiko import SSHClient, AutoAddPolicy
from paramiko.ssh_exception import SSHException
from paramiko.util import log_to_file
from typing import Dict, Union, List
from pathlib import Path
import os
from types import MappingProxyType
import sys

# The progress bar for download and upload files.
from helper.progress_bar import progress_bar

# Consolidate Errors
CONN_EXCEPTION = SSHException, TypeError, PermissionError

# Home directory for OS, works on Linux and Windows, not sure about others.
HOME_PATH = str(Path.home())

# https://docs.python.org/2/library/logging.html#levels
LOG_LEVEL = MappingProxyType(
    {
        "NOTSET": 0,
        "DEBUG": 10,
        "INFO": 20,
        "WARNING": 30,
        "ERROR": 40,
        "CRITICAL": 50
    }
)


def get_file_hash(base_path: str = None, filename: str = None):
    """
    https://nitratine.net/blog/post/how-to-hash-files-in-python/
    The concept, read the target file block by block, each block size is 65535 bytes.
    On each block calculate the hash until the EOF.
    Then write the entire digest back to a file.
    I have tested this with an 8KB Python script and a 600MB CentOS ISO and both results are good; on CentOS
    I used sha256sum to calculate the digests and they match the hash files uploaded.
    :param filename:
        name of the file
    :param base_path:
        base directory of the file, I use this because I need the base_path to write the hash to file.
    :return:
    """
    file_path = os.path.join(base_path, filename)

    # normally loaded at first, but i prefer the library to be loaded if in use.
    import hashlib

    BLOCK_SIZE = 65535
    digest = hashlib.sha256()
    with open(file_path, "rb") as file:
        # read the first 65535 bytes from the file.
        file_blocks = file.read(BLOCK_SIZE)
        while len(file_blocks) > 0:
            digest.update(file_blocks)
            file_blocks = file.read(BLOCK_SIZE)
    with open(os.path.join(base_path, f"{filename}.sha256"), "w") as write_hash:
        write_hash.write(digest.hexdigest())
    return f"{filename}.sha256"


class LinuxSSH(SSHClient):
    """
    This is a sub class of SSHClient, this is to add on some methods while still having the functionalities of
    SSHClient.
    """

    def __init__(self, username: str = None,
                 password: str = None,
                 hostname: str = "127.0.0.1",
                 port: int = 22,
                 log_file: str = None,
                 level: int = None):
        super().__init__()
        if log_file is not None and level is not None:
            log_to_file(log_file, level=level)
        self.username = username
        self.password = password
        self.hostname = hostname
        self.port = port
        self._policy = AutoAddPolicy()  # modify the original default RejectPolicy()
        # Once an instance is created, a ssh session is attempted.
        try:
            self.connect(self.hostname, port=self.port, username=self.username, password=self.password)
        except CONN_EXCEPTION as CE:
            sys.stdout.write(f"Message: {str(CE)}")
            sys.exit(1)

    def get_project_dirs(self, dirname: str = "/var/lib/awx/projects") -> Union[Dict[str, str], Dict[str, List]]:
        """
        To get a list of directories under the Ansible AWX base project directory.
        :param dirname:
            Project base directory. Ansible searches the yaml file from base directory.
        :return:
            dictionary of response.
        """
        try:
            # exec_command throws up a tuple(stdin, stdout, stderr)
            stdin, stdout, _ = self.exec_command(f"sudo ls -lah {dirname}", get_pty=True)
            stdin.write(self.password + "\n")
            stdin.flush()
            stdout_results = stdout.read().decode("utf-8")
            pbdirs = list()
            for d in stdout_results.splitlines()[3:]:  # Not interested in password, sudo prompt and total 0
                pbdirs.append(d.split()[-1])  # Only directory name.
            return {
                "status": "success",
                "playbook_dirs": pbdirs[2:] if len(pbdirs) > 2 else []  # not interested in . and ..
            }
        except CONN_EXCEPTION as CE:
            return {
                "status": "failed",
                "message": str(CE)
            }

    def create_project_dir(self, base_path: str = "/var/lib/awx/projects", dirname: str = None):
        """
        Create project directory.
        :param base_path:
            Project base path
        :param dirname:
            New project directory, this directory is attached to the manual project of Ansible AWX
        :return:
        """
        response = self.get_project_dirs(dirname=base_path)
        if response["status"] == "success":
            # a list of directory names where the playbook is stored.
            pbdirs = response["playbook_dirs"]
        else:
            return response
        if dirname not in pbdirs or pbdirs == list():
            # create the directory if not exists.
            commands = [f"sudo mkdir {base_path}/{dirname}",
                        f"sudo chown -R awx:awx {base_path}/{dirname}"]
            try:
                for command in commands:
                    stdin, stdout, _ = self.exec_command(command, get_pty=True)
                    # sudo password
                    stdin.write(self.password + "\n")
                    stdin.flush()
                    stdout.read().decode("utf-8")  # block until the command completes
            except CONN_EXCEPTION as CE:
                return {
                    "status": "failed",
                    "message": str(CE)
                }
        else:
            return {
                "status": "failed",
                "message": f"{dirname} exists."
            }

    def remove_project_dir(self, base_path: str = "/var/lib/awx/projects", dirname: str = None):
        """
        This is for project directory removal, project directory is required to be attached to Ansible AWX
        project.
        :param base_path:
            Project base path.
        :param dirname:
            directory name under the base path.
        :return:
        """
        if dirname is None:
            return {
                "status": "failed",
                "message": "You have forgotten to provide the project dir name you want to remove."
            }
        response = self.get_project_dirs(dirname=base_path)
        if response["status"] == "success":
            # list of directory names under project base directory.
            pbdirs = response["playbook_dirs"]
        else:
            return response
        if dirname in pbdirs:
            # if the requested directory for removal exists, prepare the rm command.
            command = f"sudo rm -rf {base_path}/{dirname}"
            try:
                stdin, stdout, stderr = self.exec_command(command, get_pty=True)
                # sudo password
                stdin.write(self.password + "\n")
                stdin.flush()
            except CONN_EXCEPTION as CE:
                return {
                    "status": "failed",
                    "message": str(CE)
                }
        else:
            return {
                "status": "failed",
                "message": f"{dirname} does not exists."
            }

    def download(self, src_abs_path: str = None,
                 dst_path: str = HOME_PATH,
                 dst_filename: str = None):
        """
        This method does sftp download, SSHClient can easily create a SFTPClient object by calling
        self.open_sftp(), once a SFTPClient instance is created we can use the get method to download.
        :param src_abs_path:
            Absolute path of the file you wish to download from the remote CentOS server.
        :param dst_path:
            Local path of your computer where this script is executed.
        :param dst_filename:
            Filename to be created for the downloaded object; this is optional, you can instead put the full path
            with the file name in dst_path.
        :return:
        """
        if dst_filename is not None:
            local_path = os.path.join(dst_path, dst_filename)
        else:
            local_path = dst_path
        with self.open_sftp() as sftp:
            """
            The sftp open and close session is handled here, so user does not need to close the sftp session.
            The purpose is to have a more straightforward way to download/upload files to target server.
            """
            # B is byte, b is bit. Miniters will have a better looking bar.
            # ascii=True will give # as the bar.
            callback, pbar = progress_bar(unit="B", unit_scale=True, miniters=1)
            try:
                sftp.get(src_abs_path, local_path, callback=callback)
            except CONN_EXCEPTION as CE:
                return {
                    "status": "failed",
                    "message": str(CE)
                }

    def upload(self, src_path: str = HOME_PATH,
               src_filename: str = None,
               dst_abs_path: str = None):
        """
        This method uploads the file from your computer to remote server.
        :param src_filename:
            the name of the file to upload; this is also used with get_file_hash.
        :param src_path:
            the base path of where the src_filename can be found.
        :param dst_abs_path:
            The absolute path the file will be uploaded. os.path.join works properly in source, if the remote is
            a different OS the path will be wrong.
        :return:
        """
        if src_filename is not None:
            local_path = os.path.join(src_path, src_filename)
        else:
            local_path = src_path
        remote_path = dst_abs_path
        digest_filename = get_file_hash(base_path=src_path, filename=src_filename)
        digest_abs_path = os.path.join(src_path, digest_filename)
        with self.open_sftp() as sftp:
            """
            The sftp open and close session is handled here, so user does not need to close the sftp session.
            The purpose is to have a more straightforward way to download/upload files to target server.
            """
            # b is bit, B is byte. If ascii=True, # will be used for progress bar.
            callback, pbar = progress_bar(unit="B", unit_scale=True, miniters=1)
            try:
                sftp.put(local_path, remote_path, callback=callback)
                # The session to remote_path is still on, hence only target filename is required.
                sftp.put(digest_abs_path, digest_filename, callback=callback)
            except CONN_EXCEPTION as CE:
                return {
                    "status": "failed",
                    "message": str(CE)
                }