Implement a class to handle firmware management tasks

This commit is contained in:
1e9abhi1e10
2024-05-29 08:54:13 +05:30
parent 8f0cf5b3a3
commit 0bd5c22681
3 changed files with 182 additions and 48 deletions

View File

@@ -17,6 +17,10 @@ import asyncio
import ipaddress
import warnings
from typing import List, Optional, Protocol, Tuple, Type, TypeVar, Union
from pathlib import Path
import re
import httpx
import hashlib
from pyasic.config import MinerConfig
from pyasic.data import Fan, HashBoard, MinerData
@@ -523,6 +527,7 @@ class MinerProtocol(Protocol):
class BaseMiner(MinerProtocol):
def __init__(self, ip: str) -> None:
self.ip = ip
self.firmware_manager = FirmwareManager("http://feeds.braiins-os.com")
if self.expected_chips is None and self.raw_model is not None:
warnings.warn(
@@ -539,4 +544,116 @@ class BaseMiner(MinerProtocol):
self.ssh = self._ssh_cls(ip)
AnyMiner = TypeVar("AnyMiner", bound=BaseMiner)
AnyMiner = TypeVar("AnyMiner", bound=BaseMiner)
class FirmwareManager:
class FirmwareManager:
def __init__(self, remote_server_url: str):
"""
Initialize a FirmwareManager instance.
Args:
remote_server_url (str): The URL of the remote server to fetch firmware information.
"""
self.remote_server_url = remote_server_url
self.version_extractors = {}
# Register version extractor for braiins_os
self.register_version_extractor("braiins_os", self.extract_braiins_os_version)
def extract_braiins_os_version(self, firmware_file: Path) -> str:
"""
Extract the firmware version from the filename for braiins_os miners.
Args:
firmware_file (Path): The firmware file to extract the version from.
Returns:
str: The extracted firmware version.
Raises:
ValueError: If the version is not found in the filename.
"""
match = re.search(r"firmware_v(\d+\.\d+\.\d+)\.tar\.gz", firmware_file.name)
if match:
return match.group(1)
raise ValueError("Firmware version not found in the filename.")
async def get_latest_firmware_info(self) -> dict:
"""
Fetch the latest firmware information from the remote server.
Returns:
dict: The latest firmware information, including version and SHA256 hash.
Raises:
httpx.HTTPStatusError: If the HTTP request fails.
"""
async with httpx.AsyncClient() as client:
response = await client.get(f"{self.remote_server_url}/latest")
response.raise_for_status()
return response.json()
async def download_firmware(self, url: str, file_path: Path):
"""
Download the firmware file from the specified URL and save it to the given file path.
Args:
url (str): The URL to download the firmware from.
file_path (Path): The file path to save the downloaded firmware.
Raises:
httpx.HTTPStatusError: If the HTTP request fails.
"""
async with httpx.AsyncClient() as client:
response = await client.get(url)
response.raise_for_status()
with file_path.open("wb") as firmware_file:
firmware_file.write(response.content)
def calculate_sha256(self, file_path: Path) -> str:
"""
Calculate the SHA256 hash of the specified file.
Args:
file_path (Path): The file path of the file to calculate the hash for.
Returns:
str: The SHA256 hash of the file.
"""
sha256 = hashlib.sha256()
with file_path.open("rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
sha256.update(chunk)
return sha256.hexdigest()
def register_version_extractor(self, miner_type: str, extractor_func):
"""
Register a custom firmware version extraction function for a specific miner type.
Args:
miner_type (str): The type of miner.
extractor_func (function): The function to extract the firmware version from the firmware file.
"""
self.version_extractors[miner_type] = extractor_func
def get_firmware_version(self, miner_type: str, firmware_file: Path) -> str:
"""
Extract the firmware version from the firmware file using the registered extractor function for the miner type.
Args:
miner_type (str): The type of miner.
firmware_file (Path): The firmware file to extract the version from.
Returns:
str: The firmware version.
Raises:
ValueError: If no extractor function is registered for the miner type or if the version is not found.
"""
if miner_type not in self.version_extractors:
raise ValueError(f"No version extractor registered for miner type: {miner_type}")
extractor_func = self.version_extractors[miner_type]
return extractor_func(firmware_file)

View File

@@ -4,7 +4,15 @@ from pyasic.ssh.base import BaseSSH
import logging
import httpx
from pathlib import Path
import os
import hashlib
from pyasic.miners.base import FirmwareManager
def calculate_sha256(file_path):
sha256 = hashlib.sha256()
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
sha256.update(chunk)
return sha256.hexdigest()
# Set up logging
logger = logging.getLogger(__name__)
@@ -24,6 +32,27 @@ class BOSMinerSSH(BaseSSH):
"""
super().__init__(ip)
self.pwd = settings.get("default_bosminer_ssh_password", "root")
self.firmware_manager = FirmwareManager()
def get_firmware_version(self, firmware_file):
"""
Extract the firmware version from the firmware file.
Args:
firmware_file (file): The firmware file to extract the version from.
Returns:
str: The firmware version.
"""
import re
# Extract the version from the filename using a regular expression
filename = firmware_file.name
match = re.search(r"firmware_v(\d+\.\d+\.\d+)\.tar\.gz", filename)
if match:
return match.group(1)
else:
raise ValueError("Firmware version not found in the filename.")
async def get_board_info(self):
"""
@@ -106,12 +135,14 @@ class BOSMinerSSH(BaseSSH):
"""
return await self.send_command("cat /sys/class/leds/'Red LED'/delay_off")
async def upgrade_firmware(self, file_location: str = None):
async def upgrade_firmware(self, file_location: str = None, custom_url: str = None, override_validation: bool = False):
"""
Upgrade the firmware of the BOSMiner device.
Args:
file_location (str): The local file path of the firmware to be uploaded. If not provided, the firmware will be downloaded from the internal server.
custom_url (str): Custom URL to download the firmware from.
override_validation (bool): Whether to override SHA256 validation.
Returns:
str: Confirmation message after upgrading the firmware.
@@ -126,38 +157,36 @@ class BOSMinerSSH(BaseSSH):
if cached_file_location.exists():
logger.info("Cached firmware file found. Checking version.")
# Compare cached firmware version with the latest version on the server
async with httpx.AsyncClient() as client:
response = await client.get("http://firmware.pyasic.org/latest")
response.raise_for_status()
latest_version = response.json().get("version")
cached_version = self._get_fw_ver()
latest_firmware_info = await self.firmware_manager.get_latest_firmware_info()
latest_version = latest_firmware_info.get("version")
latest_hash = latest_firmware_info.get("sha256")
cached_version = self.firmware_manager.get_firmware_version("braiins_os", cached_file_location)
if cached_version == latest_version:
logger.info("Cached firmware version matches the latest version. Using cached file.")
file_location = str(cached_file_location)
else:
logger.info("Cached firmware version does not match the latest version. Downloading new version.")
firmware_url = response.json().get("url")
firmware_url = custom_url or latest_firmware_info.get("url")
if not firmware_url:
raise ValueError("Firmware URL not found in the server response.")
async with httpx.AsyncClient() as client:
firmware_response = await client.get(firmware_url)
firmware_response.raise_for_status()
with cached_file_location.open("wb") as firmware_file:
firmware_file.write(firmware_response.content)
await self.firmware_manager.download_firmware(firmware_url, cached_file_location)
if not override_validation:
downloaded_hash = self.firmware_manager.calculate_sha256(cached_file_location)
if downloaded_hash != latest_hash:
raise ValueError("SHA256 hash validation failed for the downloaded firmware file.")
file_location = str(cached_file_location)
else:
logger.info("No cached firmware file found. Downloading new version.")
async with httpx.AsyncClient() as client:
response = await client.get("http://firmware.pyasic.org/latest")
response.raise_for_status()
firmware_url = response.json().get("url")
latest_firmware_info = await self.firmware_manager.get_latest_firmware_info()
firmware_url = custom_url or latest_firmware_info.get("url")
latest_hash = latest_firmware_info.get("sha256")
if not firmware_url:
raise ValueError("Firmware URL not found in the server response.")
async with httpx.AsyncClient() as client:
firmware_response = await client.get(firmware_url)
firmware_response.raise_for_status()
with cached_file_location.open("wb") as firmware_file:
firmware_file.write(firmware_response.content)
await self.firmware_manager.download_firmware(firmware_url, cached_file_location)
if not override_validation:
downloaded_hash = self.firmware_manager.calculate_sha256(cached_file_location)
if downloaded_hash != latest_hash:
raise ValueError("SHA256 hash validation failed for the downloaded firmware file.")
file_location = str(cached_file_location)
# Upload the firmware file to the BOSMiner device
@@ -174,6 +203,18 @@ class BOSMinerSSH(BaseSSH):
logger.info("Firmware upgrade process completed successfully.")
return result
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error occurred during the firmware upgrade process: {e}")
raise
except FileNotFoundError as e:
logger.error(f"File not found during the firmware upgrade process: {e}")
raise
except ValueError as e:
logger.error(f"Validation error occurred during the firmware upgrade process: {e}")
raise
except OSError as e:
logger.error(f"OS error occurred during the firmware upgrade process: {e}")
raise
except Exception as e:
logger.error(f"An error occurred during the firmware upgrade process: {e}")
logger.error(f"An unexpected error occurred during the firmware upgrade process: {e}", exc_info=True)
raise

View File

@@ -1,24 +0,0 @@
import pytest
from unittest.mock import patch, mock_open
from pyasic.ssh.braiins_os import BOSMinerSSH
@pytest.fixture
def bosminer_ssh():
return BOSMinerSSH(ip="192.168.1.100")
@pytest.mark.asyncio
async def test_upgrade_firmware_with_valid_file_location(bosminer_ssh):
with patch("pyasic.ssh.braiins_os.os.path.exists") as mock_exists, \
patch("pyasic.ssh.braiins_os.open", mock_open(read_data="data")) as mock_file, \
patch("pyasic.ssh.braiins_os.requests.get") as mock_get, \
patch.object(bosminer_ssh, "send_command") as mock_send_command:
mock_exists.return_value = False
file_location = "/path/to/firmware.tar.gz"
result = await bosminer_ssh.upgrade_firmware(file_location=file_location)
mock_send_command.assert_any_call(f"scp {file_location} root@{bosminer_ssh.ip}:/tmp/firmware.tar.gz")
mock_send_command.assert_any_call("tar -xzf /tmp/firmware.tar.gz -C /tmp")
mock_send_command.assert_any_call("sh /tmp/upgrade_firmware.sh")
assert result is not None