feature: made User and Group classes inherit from YAMLObject class for ease in load/dump operations and conversion

This commit is contained in:
2026-01-22 17:10:15 -05:00
parent 8e6346ee21
commit 66da080109

View File

@@ -1,57 +1,98 @@
from typing import Self, Literal, Never, Callable, Sequence from typing import Self, Literal, Never, Callable, Sequence
from custtypes import Roles, Scopes, PathCollection, PathRoles, Software, SoftwareRoles, ExecutedPath, PacMans, UserName, GroupName, IdlePath, ExecutedPath from custtypes import Roles, Scopes, PathCollection, PathRoles
from custtypes import Software, SoftwareRoles, PacMans
from custtypes import ExecutedPath, IdlePath
from custtypes import UserName, GroupName
from pathlib import Path, PurePath from pathlib import Path, PurePath
from sshkey import SSHKeyCollection, SSHKeyType, SSHKey from sshkey import SSHKeyCollection, SSHKeyType, SSHKey
from random import choice as gamble from random import choice as gamble
from re import Pattern as RegEx from re import Pattern as RegEx
from softman import sshd from softman import sshd
from yaml import YAMLObject
class Group:
def __init__(self, group_name: GroupName = GroupName.sudo, gid: int = 27):
self.group_name = group_name
self.id = gid
self.category: Literal["system", "regular"] = "system"
class User: class Group(YAMLObject):
def __init__(self, username: UserName | str = UserName.root.name.lower(), password: str = "test", services: list = [Software.openssh_server.name.lower()], uid: int = 0): yaml_tag = u"!Group"
# @TODO create Enum class child for category parameter type hinting in below method
def __init__(self, group_name: GroupName = GroupName.sudo, category: Literal["system", "regular"] = "system", gid: int | str = 27):
if isinstance(group_name, GroupName):
self.group_name = group_name.name.lower()
else:
self.group_name = group_name
self.id = str(gid)
self.type = category
def __repr__(self):
return "%s(group_name=%r,category=%r,gid=%r)" % (
self.__class__.__name__,
self.group_name,
self.category,
self.id
)
class User(YAMLObject):
yaml_tag = u"!User"
def __init__(self, username: UserName | str = UserName.root, password: str = "test", services: list[str | Software] = [Software.openssh_server], uid: int | str = 0):
self.exists = True self.exists = True
self.username = username
self.id = uid if isinstance(username, UserName):
self.username = username.name.lower()
else:
self.username = username
self.id = str(uid)
self.password = password self.password = password
self.services: tuple = tuple(services)
new_services = []
for s in services:
if isinstance(s, Software):
new_services.append(s.name.lower())
else:
new_services.append(s)
self.services: tuple = tuple(new_services)
self.shell = "/bin/bash" self.shell = "/bin/bash"
self.home = "/" self.home = "/"
self.category: Literal["system", "regular"] = "regular" self.category: Literal["system", "regular"] = "regular"
group = Group(username, self.id) group = Group(username, self.id)
self.primary_group = group self.group = group
self.supp_groups = None self.groups: list[str | GroupName] | None = None
if self.supp_groups is None: if self.groups is None:
self.is_admin = True self.admin = True
elif isinstance(self.supp_groups, Sequence) and GroupName.sudo in self.supp_groups: elif isinstance(self.groups, Sequence) and GroupName.sudo in self.groups:
self.is_admin = True self.admin = True
else: else:
self.is_admin = False self.admin = False
ssh_keys = SSHKeyCollection() ssh_keys = SSHKeyCollection()
ssh_keys.pull() ssh_keys.pull()
self.__ssh_keys = ssh_keys self.__ssh_keys = ssh_keys
self.__public_keys = self.__ssh_keys.publish(SSHKeyType.pubkey, datatype=list)
self.__private_keys = self.__ssh_keys.publish(SSHKeyType.privkey, datatype=list)[0]
# print("here") # print("here")
self.__public_keys = self.__ssh_keys.publish(SSHKeyType.pubkey, datatype=list)
pubkeys = ssh_keys.publish(SSHKeyType.pubkey, datatype=list) pubkeys = ssh_keys.publish(SSHKeyType.pubkey, datatype=list)
self.__auth_keys: SSHKeyCollection = SSHKeyCollection() self.__auth_keys: SSHKeyCollection = SSHKeyCollection()
for p in pubkeys: for p in pubkeys:
self.__auth_keys.append(p) self.__auth_keys.append(p)
self.ssh_authorized_keys: list[str | None] = []
self.__private_keys = self.__ssh_keys.publish(SSHKeyType.privkey, datatype=list)[0]
privkeys = ssh_keys.publish(SSHKeyType.privkey, datatype=list) privkeys = ssh_keys.publish(SSHKeyType.privkey, datatype=list)
self.__priv_keys: SSHKeyCollection = SSHKeyCollection() self.__priv_keys: SSHKeyCollection = SSHKeyCollection()
for p in privkeys[0]: for p in privkeys[0]:
self.__priv_keys.append(p) self.__priv_keys.append(p)
self.ssh_private_key_paths: list[str | None] = []
self.__priv_key_pref: int = privkeys[1] self.__priv_key_pref: int = privkeys[1]
self.ssh_private_key_path_pref: int = privkeys[1]
self.__apps = (sshd,) self.__apps = (sshd,)
self.__ssh_keypairs: tuple[SSHKey | tuple[SSHKey]] | SSHKey = tuple() self.__ssh_keypairs: tuple[SSHKey | tuple[SSHKey]] | SSHKey = tuple()
self.__ssh_keypair_chosen = False self.__ssh_keypair_chosen = False
def get_app(self, name: Software) -> Never: def get_app(self, name: Software) -> Never:
raise NotImplementedError raise NotImplementedError
@@ -68,8 +109,8 @@ class User:
if self.username not in users: if self.username not in users:
users[self.username] = dict() users[self.username] = dict()
users[self.username]["auth_keys"] = self.__auth_keys users[self.username]["authorized_keys"] = self.__auth_keys
users[self.username]["priv_keys"] = self.__priv_keys users[self.username]["credential_keys"] = self.__priv_keys
users[self.username]["keys"] = self.__ssh_keys users[self.username]["keys"] = self.__ssh_keys
users[self.username]["keypairs"] = self.__ssh_keypairs users[self.username]["keypairs"] = self.__ssh_keypairs
users[self.username]["preferred_priv_key"] = self.__priv_key_pref users[self.username]["preferred_priv_key"] = self.__priv_key_pref
@@ -170,14 +211,16 @@ class User:
raise KeyError raise KeyError
self.__auth_keys.append(public_key) self.__auth_keys.append(public_key)
self.ssh_authorized_keys.append(public_key.read_text())
self.__priv_keys.append(private_key) self.__priv_keys.append(private_key)
self.ssh_private_key_paths.append(str(private_key))
self.__ssh_keypairs = (*self.__ssh_keypairs, (self.__priv_keys.tail, self.__auth_keys.tail),) self.__ssh_keypairs = (*self.__ssh_keypairs, (self.__priv_keys.tail, self.__auth_keys.tail),)
self.__priv_key_pref = len(self.__ssh_keypairs) - 1 self.__priv_key_pref = len(self.__ssh_keypairs) - 1
self.ssh_private_key_path_pref = len(self.__ssh_keypairs) - 1
self.__update_sshd() self.__update_sshd()
self.__ssh_keypair_chosen = True self.__ssh_keypair_chosen = True
def get_keypair(self, preference: int): def get_keypair(self, preference: int):
if not self.__ssh_keypair_chosen: if not self.__ssh_keypair_chosen:
raise Exception raise Exception
@@ -194,7 +237,7 @@ class User:
return self.__ssh_keypairs[preference] return self.__ssh_keypairs[preference]
@property @property
def available_keys(self) -> SSHKeyCollection: def keys(self) -> SSHKeyCollection:
return self.__ssh_keys return self.__ssh_keys
@property @property
@@ -229,7 +272,7 @@ class User:
count += 1 count += 1
@property @property
def available_keypairs(self) -> tuple[SSHKey] | SSHKey | None: def keypairs(self) -> tuple[SSHKey] | SSHKey | None:
return self.__ssh_keypairs return self.__ssh_keypairs
@property @property
@@ -245,3 +288,12 @@ class User:
def credential_keys(self) -> SSHKeyCollection: def credential_keys(self) -> SSHKeyCollection:
return self.__priv_keys return self.__priv_keys
def __repr__(self) -> str:
return "%s(username=%r,password=%r,services=%r,uid=%r)" % (
self.__class__.__name__,
self.username,
self.password,
self.services,
self.id
)