openpilot/frogpilot/assets/model_manager.py
2025-11-01 12:00:00 -07:00

544 lines
23 KiB
Python

#!/usr/bin/env python3
import json
import re
import requests
import shutil
import time
import urllib.parse
from pathlib import Path
from urllib.parse import quote_plus
from openpilot.common.basedir import BASEDIR
from openpilot.frogpilot.assets.download_functions import GITLAB_URL, download_file, get_remote_file_size, get_repository_url, handle_error, handle_request_error, verify_download
from openpilot.frogpilot.common.frogpilot_utilities import delete_file, extract_tar, load_json_file, update_json_file
from openpilot.frogpilot.common.frogpilot_variables import (
DEFAULT_MODEL, DEFAULT_MODEL_NAME, DEFAULT_MODEL_VERSION, MODELS_PATH, RESOURCES_REPO, TINYGRAD_FILES,
params, params_default, params_memory, update_frogpilot_toggles
)
# Version tag of the remote model list / tinygrad bundle this build expects; it is
# embedded in the tinygrad archive name and written to VERSION_PATH by validate_models().
VERSION = "v16"
VERSION_PATH = MODELS_PATH.joinpath("model_version")

# params_memory keys used to request, cancel, and report download progress.
CANCEL_DOWNLOAD_PARAM = "CancelModelDownload"
DOWNLOAD_PROGRESS_PARAM = "ModelDownloadProgress"
MODEL_DOWNLOAD_PARAM = "ModelToDownload"
MODEL_DOWNLOAD_ALL_PARAM = "DownloadAllModels"
UPDATE_TINYGRAD_PARAM = "UpdateTinygrad"

# Tinygrad bundle: archive name on the remote, and the size validate_models() records for
# it after a version change.
TAR_FILE_NAME = f"Tinygrad_{VERSION}.tar.gz"
DEFAULT_TINYGRAD_SIZE = 87_746_736

# Install locations replaced by update_tinygrad().
TINYGRAD_MODELD_PATH = Path(BASEDIR, "frogpilot/tinygrad_modeld")
TINYGRAD_REPO_PATH = Path(BASEDIR, "tinygrad_repo")
class ModelManager:
  """Keep the local model library in sync with the FrogPilot resources repository.

  This class mirrors the remote model list into ``params``, deletes stale or partially
  downloaded model files, downloads models (a single ``.thneed`` file or the TINYGRAD_FILES
  set) with a GitLab fallback, and installs the tinygrad bundle. Local file sizes are cached
  in JSON files under MODELS_PATH and compared with remote sizes to detect outdated files.
  """

  # Model versions shipped as a single "<model>.thneed" file; every other version is shipped
  # as the per-model files listed in TINYGRAD_FILES.
  THNEED_MODEL_VERSIONS = frozenset({"v1", "v2", "v3", "v4", "v5", "v6"})

  # File stems written by copy_default_model() (which __init__ runs when boot_run is set);
  # the stale-file cleanup in check_models() never deletes them.
  CLASSIC_DEFAULT_MODEL = "wd-40"
  DEFAULT_THNEED_MODEL = "national-public-radio"

  def __init__(self, boot_run=False):
    """Load cached model metadata and optionally install/validate the bundled defaults.

    :param boot_run: when True, copy the bundled default models into MODELS_PATH and run
        validate_models() before returning.
    """
    self.downloading_model = False

    # An unset param becomes [""] (not []); validate_models() checks the three lists have
    # matching lengths.
    self.available_models = (params.get("AvailableModels", encoding="utf-8") or "").split(",")
    self.available_model_names = (params.get("AvailableModelNames", encoding="utf-8") or "").split(",")
    self.model_versions = (params.get("ModelVersions", encoding="utf-8") or "").split(",")

    self.model_sizes_path = MODELS_PATH / "model_sizes.json"
    self.tinygrad_sizes_path = MODELS_PATH / "tinygrad_sizes.json"
    self.model_sizes = load_json_file(self.model_sizes_path)
    self.tinygrad_sizes = load_json_file(self.tinygrad_sizes_path)

    self.session = requests.Session()
    self.session.headers.update({
      "Accept-Language": "en",
      "User-Agent": "frogpilot-model-downloader/1.0 (https://github.com/FrogAi/FrogPilot)",
    })

    if boot_run:
      self.copy_default_model()
      self.validate_models()

  def check_models(self, boot_run, repo_url):
    """Delete stale model files and flag missing/outdated models for download.

    :param boot_run: when True, only the cleanup runs; the size check and auto-download
        flag are skipped.
    :param repo_url: base URL of the resources repository, used to fetch remote file sizes.
    """
    available = set(self.available_models)
    self._remove_unlisted_models(available)

    for tmp_file in MODELS_PATH.glob("tmp*"):
      if tmp_file.is_file():
        delete_file(tmp_file)

    current_model = (params.get("Model", encoding="utf-8") or "").removesuffix("_default")
    if current_model not in available:
      params.put("Model", params_default.get("Model", encoding="utf-8"))

    if boot_run or not params.get_bool("AutomaticallyDownloadModels"):
      return

    model_sizes = self.fetch_all_model_sizes(repo_url)
    if not model_sizes:
      print("No model size data available. Skipping model checks...")
      return

    need_to_update_models = False
    for model in self.available_models:
      if not model:
        continue

      if self.is_tinygrad_model(model):  # True means a single-file .thneed model
        model_file = MODELS_PATH / f"{model}.thneed"
        if not model_file.is_file():
          need_to_update_models = True
          continue

        if self._is_outdated(model_file.name, model_sizes):
          print(f"Model {model} is outdated. Deleting {model_file}...")
          delete_file(model_file)
          need_to_update_models = True
      else:
        filenames = [f"{model}_{filename}" for filename, _ in TINYGRAD_FILES]
        model_missing = not all((MODELS_PATH / name).is_file() for name in filenames)
        model_outdated = any(self._is_outdated(name, model_sizes) for name in filenames)

        if model_missing or model_outdated:
          need_to_update_models = True
          print(f"Model {model} is either missing required files or outdated. Deleting...")
          for name in filenames:
            delete_file(MODELS_PATH / name)

    if need_to_update_models:
      params_memory.put_bool(MODEL_DOWNLOAD_ALL_PARAM, True)

  def _remove_unlisted_models(self, available):
    """Delete model files in MODELS_PATH whose model id is not in ``available``.

    Non-model files (anything that is neither ``<id>.thneed`` nor ``<id>_<tinygrad file>``)
    are ignored, as are the defaults written by copy_default_model().
    """
    listed = {model for model in available if model}
    if not listed:
      # With no model list every file would look stale; keep everything.
      return

    protected = {DEFAULT_MODEL, self.CLASSIC_DEFAULT_MODEL, self.DEFAULT_THNEED_MODEL}
    tinygrad_suffixes = [f"_{filename}" for filename, _ in TINYGRAD_FILES]

    # Materialize the listing before deleting entries from the same directory.
    for model_file in sorted(MODELS_PATH.iterdir()):
      if not model_file.is_file():
        continue
      model = self._model_name_from_file(model_file.name, tinygrad_suffixes)
      if model is None or model in listed or model in protected:
        continue
      print(f"Removing outdated model: {model_file}")
      delete_file(model_file)

  @staticmethod
  def _model_name_from_file(file_name, tinygrad_suffixes):
    """Return the model id encoded in ``file_name``, or None if it is not a model file.

    Recognizes ``<id>.thneed`` and ``<id><suffix>`` for each suffix in
    ``tinygrad_suffixes``. Longer suffixes are tried first, so a suffix that is itself the
    tail of another one cannot truncate the id.
    """
    if file_name.endswith(".thneed"):
      return file_name.removesuffix(".thneed") or None
    for suffix in sorted(tinygrad_suffixes, key=len, reverse=True):
      if file_name.endswith(suffix):
        return file_name.removesuffix(suffix) or None
    return None

  def _is_outdated(self, file_name, remote_sizes):
    """Return True when the remote size of ``file_name`` is known and differs from the recorded local size.

    A missing or falsy remote size counts as unknown (not outdated) instead of raising.
    """
    expected_size = int(remote_sizes.get(file_name) or 0)
    return expected_size > 0 and self.model_sizes.get(file_name) != expected_size

  def check_tinygrad(self, repo_url):
    """Set TinygradUpdateAvailable when the remote tinygrad archive size differs from the recorded one."""
    tinygrad_url = f"{repo_url}/Tinygrad/{TAR_FILE_NAME}"
    expected_size = get_remote_file_size(tinygrad_url, self.session) or 0
    local_size = int(self.tinygrad_sizes.get(TAR_FILE_NAME, 0))

    if expected_size > 0 and local_size != expected_size:
      print(f"Tinygrad version {VERSION} is outdated, expected_size: {expected_size}, local_size: {local_size}, flagging for update...")
      params.put_bool("TinygradUpdateAvailable", True)

  @staticmethod
  def _needs_copy(source, target):
    """Return True when ``source`` exists and ``target`` is missing or differs in size."""
    return source.is_file() and (not target.is_file() or source.stat().st_size != target.stat().st_size)

  def copy_default_model(self):
    """Copy the models bundled with this checkout into MODELS_PATH when missing or a different size.

    Sizes are recorded for the two ``.thneed`` copies only. The tinygrad default files are
    copied without updating model_sizes.json.
    """
    classic_default_model_path = MODELS_PATH / f"{self.CLASSIC_DEFAULT_MODEL}.thneed"
    source_path = Path(__file__).parents[1] / "classic_modeld/models/supercombo.thneed"
    if self._needs_copy(source_path, classic_default_model_path):
      shutil.copyfile(source_path, classic_default_model_path)
      print(f"Copied the classic default model from {source_path} to {classic_default_model_path}")
      self.update_model_size(classic_default_model_path)

    default_model_path = MODELS_PATH / f"{self.DEFAULT_THNEED_MODEL}.thneed"
    source_path = Path(__file__).parents[2] / "selfdrive/modeld/models/supercombo.thneed"
    if self._needs_copy(source_path, default_model_path):
      shutil.copyfile(source_path, default_model_path)
      print(f"Copied the default model from {source_path} to {default_model_path}")
      self.update_model_size(default_model_path)

    for filename, description in TINYGRAD_FILES:
      source = TINYGRAD_MODELD_PATH / "models" / filename
      target = MODELS_PATH / f"{DEFAULT_MODEL}_{filename}"
      if self._needs_copy(source, target):
        shutil.copyfile(source, target)
        print(f"Copied the tinygrad {description} from {source} to {target}")

  def _is_downloaded(self, model):
    """Return True when every file ``model`` needs is present in MODELS_PATH."""
    if self.is_tinygrad_model(model):
      return (MODELS_PATH / f"{model}.thneed").is_file()
    return all((MODELS_PATH / f"{model}_{filename}").is_file() for filename, _ in TINYGRAD_FILES)

  def _display_name(self, model):
    """Return the human-readable name paired with ``model`` in the available lists."""
    return self.available_model_names[self.available_models.index(model)]

  def download_all_models(self):
    """Refresh the model list, then download every listed model that is not fully present.

    Stops when CANCEL_DOWNLOAD_PARAM is set; clears MODEL_DOWNLOAD_ALL_PARAM on completion.
    """
    repo_url = get_repository_url(self.session)
    if not repo_url:
      handle_error(None, "GitHub and GitLab are offline...", "Repository unavailable", MODEL_DOWNLOAD_PARAM, DOWNLOAD_PROGRESS_PARAM)
      return

    self.fetch_models(f"{repo_url}/Versions/model_names_{VERSION}.json", repo_url)

    for model in self.available_models:
      if params_memory.get_bool(CANCEL_DOWNLOAD_PARAM):
        handle_error(None, "Download cancelled...", "Download cancelled...", MODEL_DOWNLOAD_ALL_PARAM, DOWNLOAD_PROGRESS_PARAM)
        return

      if self._is_downloaded(model):
        continue

      print(f"Model {model} is not downloaded. Preparing to download...")
      params_memory.put(DOWNLOAD_PROGRESS_PARAM, f"Downloading \"{self._display_name(model)}\"...")
      self.download_model(model)

    params_memory.put(DOWNLOAD_PROGRESS_PARAM, "All models downloaded!")
    params_memory.remove(MODEL_DOWNLOAD_ALL_PARAM)

  def download_model(self, model_to_download):
    """Download and verify ``model_to_download``, retrying each file once from GitLab.

    Progress and errors are reported through DOWNLOAD_PROGRESS_PARAM; MODEL_DOWNLOAD_PARAM is
    removed on success. ``self.downloading_model`` is True while this runs (update_models()
    returns early while it is set) and is reset even if a helper raises.

    :raises ValueError: if ``model_to_download`` is not in ``self.available_models``.
    """
    self.downloading_model = True
    try:
      repo_url = get_repository_url(self.session)
      if not repo_url:
        handle_error(None, "GitHub and GitLab are offline...", "Repository unavailable", MODEL_DOWNLOAD_PARAM, DOWNLOAD_PROGRESS_PARAM)
        return

      if self.is_tinygrad_model(model_to_download):
        self._download_thneed_model(model_to_download, repo_url)
      else:
        self._download_tinygrad_model(model_to_download, repo_url)
    finally:
      self.downloading_model = False

  def _handle_cancel(self, file_path):
    """If a cancel was requested, delete the partial ``file_path``, report it, and return True."""
    if not params_memory.get_bool(CANCEL_DOWNLOAD_PARAM):
      return False
    delete_file(file_path)
    handle_error(None, "Download cancelled...", "Download cancelled...", MODEL_DOWNLOAD_PARAM, DOWNLOAD_PROGRESS_PARAM)
    return True

  def _download_thneed_model(self, model, repo_url):
    """Download ``<model>.thneed`` from ``repo_url``, falling back to GitLab if verification fails."""
    model_path = MODELS_PATH / f"{model}.thneed"
    model_url = f"{repo_url}/Models/{model}.thneed"

    print(f"Downloading model: {model}")
    download_file(CANCEL_DOWNLOAD_PARAM, model_path, DOWNLOAD_PROGRESS_PARAM, model_url, MODEL_DOWNLOAD_PARAM, self.session)
    if self._handle_cancel(model_path):
      return

    if verify_download(model_path, model_url, self.session):
      print(f"Model {model} downloaded and verified successfully!")
    else:
      print(f"Verification failed for model {model}. Retrying from GitLab...")
      fallback_url = f"{GITLAB_URL}/Models/{model}.thneed"
      download_file(CANCEL_DOWNLOAD_PARAM, model_path, DOWNLOAD_PROGRESS_PARAM, fallback_url, MODEL_DOWNLOAD_PARAM, self.session)
      if self._handle_cancel(model_path):
        return
      if not verify_download(model_path, fallback_url, self.session):
        handle_error(model_path, "Verification failed...", "GitLab verification failed", MODEL_DOWNLOAD_PARAM, DOWNLOAD_PROGRESS_PARAM)
        return
      print(f"Model {model} downloaded and verified successfully from GitLab!")

    self.update_model_size(model_path)
    params_memory.put(DOWNLOAD_PROGRESS_PARAM, "Downloaded!")
    params_memory.remove(MODEL_DOWNLOAD_PARAM)

  def _download_tinygrad_model(self, model, repo_url):
    """Download every TINYGRAD_FILES part of ``model``, reporting combined progress across parts.

    Aborts before downloading anything if the remote size of any part is unknown.
    """
    all_model_sizes = self.fetch_all_model_sizes(repo_url) or {}
    filenames = [f"{model}_{file_key}" for file_key, _ in TINYGRAD_FILES]

    missing = [name for name in filenames if int(all_model_sizes.get(name) or 0) <= 0]
    if missing:
      handle_error(None, "Missing size metadata...", f"Sizes not found for: {', '.join(missing)}...", MODEL_DOWNLOAD_PARAM, DOWNLOAD_PROGRESS_PARAM)
      return

    # All sizes are known here, so progress can be reported against the combined total.
    part_sizes = [int(all_model_sizes[name]) for name in filenames]
    total_model_bytes = sum(part_sizes)
    downloaded_offset_bytes = 0

    for (_, description), filename, part_bytes in zip(TINYGRAD_FILES, filenames, part_sizes):
      model_path = MODELS_PATH / filename
      primary_url = f"{repo_url}/Models/compiled/{filename}"

      print(f"Downloading {description} for model: {model}")
      download_file(CANCEL_DOWNLOAD_PARAM, model_path, DOWNLOAD_PROGRESS_PARAM, primary_url, MODEL_DOWNLOAD_PARAM, self.session, offset_bytes=downloaded_offset_bytes, total_bytes=total_model_bytes)
      if self._handle_cancel(model_path):
        return

      if verify_download(model_path, primary_url, self.session):
        print(f"{description.capitalize()} for {model} downloaded and verified successfully!")
      else:
        print(f"Verification failed for {filename}. Retrying from GitLab...")
        fallback_url = f"{GITLAB_URL}/Models/compiled/{filename}"
        download_file(CANCEL_DOWNLOAD_PARAM, model_path, DOWNLOAD_PROGRESS_PARAM, fallback_url, MODEL_DOWNLOAD_PARAM, self.session, offset_bytes=downloaded_offset_bytes, total_bytes=total_model_bytes)
        if self._handle_cancel(model_path):
          return
        if not verify_download(model_path, fallback_url, self.session):
          handle_error(model_path, "Verification failed...", f"GitLab verification failed for {filename}", MODEL_DOWNLOAD_PARAM, DOWNLOAD_PROGRESS_PARAM)
          return
        print(f"{description.capitalize()} for {model} downloaded and verified successfully from GitLab!")

      downloaded_offset_bytes += part_bytes

    print(f"Updating model sizes for {model}...")
    for filename in filenames:
      self.update_model_size(MODELS_PATH / filename)

    params_memory.put(DOWNLOAD_PROGRESS_PARAM, "Downloaded!")
    params_memory.remove(MODEL_DOWNLOAD_PARAM)

  def fetch_all_model_sizes(self, repo_url):
    """Return ``{file name: size in bytes}`` for files on the repository's ``Models`` branch.

    Both the branch root and its ``compiled`` directory are listed. Returns {} for an
    unsupported repository URL or when a request fails.
    """
    is_github = "github" in repo_url
    is_gitlab = "gitlab" in repo_url
    repo_encoded = quote_plus(RESOURCES_REPO)

    if is_github:
      dir_urls = (
        f"https://api.github.com/repos/{RESOURCES_REPO}/contents?ref=Models",
        f"https://api.github.com/repos/{RESOURCES_REPO}/contents/compiled?ref=Models",
      )
    elif is_gitlab:
      # The GitLab tree endpoint is paginated (20 entries per page by default); request the
      # maximum page size and follow the "next" links so no file is missed.
      dir_urls = (
        f"https://gitlab.com/api/v4/projects/{repo_encoded}/repository/tree?ref=Models&per_page=100",
        f"https://gitlab.com/api/v4/projects/{repo_encoded}/repository/tree?path=compiled&ref=Models&per_page=100",
      )
    else:
      print(f"Unsupported repository URL: {repo_url}")
      return {}

    model_sizes = {}
    try:
      for api_url in dir_urls:
        model_sizes.update(self._fetch_dir_sizes(api_url, is_github, repo_encoded))
    except requests.exceptions.RequestException as e:
      handle_request_error(f"Failed to fetch model sizes from {'GitHub' if is_github else 'GitLab'}: {e}", None, None, None)
      return {}
    return model_sizes

  def _fetch_dir_sizes(self, api_url, is_github, repo_encoded):
    """Return ``{file name: size}`` for the dotted file names listed by one directory API URL.

    GitHub listings include sizes directly. For GitLab, each file's size comes from a HEAD
    request's Content-Length, and files whose HEAD fails are omitted.

    :raises requests.exceptions.RequestException: if listing the directory fails.
    """
    print(f"Fetching model metadata: {api_url}")

    entries = []
    visited = set()
    next_url = api_url
    while next_url and next_url not in visited:
      visited.add(next_url)
      response = self.session.get(next_url, timeout=10)
      response.raise_for_status()
      entries.extend(response.json())
      # requests parses the Link header into response.links; GitHub's contents API is not paged here.
      next_url = None if is_github else response.links.get("next", {}).get("url")

    # Entries without a "." (e.g. the "compiled" directory) are not model files.
    model_files = [file for file in entries if "." in file["name"]]

    sizes = {}
    if is_github:
      for file in model_files:
        sizes[file["name"]] = file.get("size", 0)
    else:
      for file in model_files:
        file_path = quote_plus(file["path"])
        metadata_url = f"https://gitlab.com/api/v4/projects/{repo_encoded}/repository/files/{file_path}/raw?ref=Models"
        head_response = self.session.head(metadata_url, timeout=10)
        if head_response.ok:
          sizes[file["name"]] = int(head_response.headers.get("content-length", 0))
    return sizes

  def fetch_models(self, url, repo_url, boot_run=False):
    """Fetch the model list JSON at ``url``, store it, then run check_models() and check_tinygrad().

    Any exception is passed to handle_request_error() and [] is returned.
    """
    try:
      response = self.session.get(url, timeout=10)
      response.raise_for_status()
      model_info = response.json().get("models", [])
      if model_info:
        self.update_model_params(model_info)
      self.check_models(boot_run, repo_url)
      self.check_tinygrad(repo_url)
    except Exception as exception:
      handle_request_error(exception, None, None, None)
      return []

  def is_tinygrad_model(self, model):
    """Return True when ``model`` is a single-file ``.thneed`` model (versions v1-v6).

    NOTE: despite the name, every caller treats True as "use ``<model>.thneed``" and False as
    "use the TINYGRAD_FILES set". The name is kept for compatibility with existing callers.

    :raises ValueError: if ``model`` is not in ``self.available_models``.
    """
    return self.model_versions[self.available_models.index(model)] in self.THNEED_MODEL_VERSIONS

  def update_model_params(self, model_info):
    """Replace the cached model lists with ``model_info`` entries (``id``/``name``/``version``) and persist them."""
    self.available_models = [model["id"] for model in model_info]
    self.available_model_names = [model["name"] for model in model_info]
    self.model_versions = [model["version"] for model in model_info]

    params.put("AvailableModels", ",".join(self.available_models))
    params.put("AvailableModelNames", ",".join(self.available_model_names))
    params.put("ModelVersions", ",".join(self.model_versions))
    print("Models list updated successfully!")

  def update_models(self, boot_run):
    """Refresh the model list from the repository unless a download is in progress."""
    if self.downloading_model:
      return

    repo_url = get_repository_url(self.session)
    if repo_url is None:
      print("GitHub and GitLab are offline...")
      return

    self.fetch_models(f"{repo_url}/Versions/model_names_{VERSION}.json", repo_url, boot_run)

  def update_model_size(self, file_path):
    """Record the on-disk size of ``file_path`` in model_sizes.json."""
    self.model_sizes[file_path.name] = file_path.stat().st_size
    update_json_file(self.model_sizes_path, self.model_sizes)
    print(f"Updated size for {file_path.name} in {self.model_sizes_path.name}")

  def update_tinygrad_size(self, file_path):
    """Record the on-disk size of the downloaded tinygrad archive under TAR_FILE_NAME."""
    self.tinygrad_sizes[TAR_FILE_NAME] = file_path.stat().st_size
    update_json_file(self.tinygrad_sizes_path, self.tinygrad_sizes)
    print(f"Updated size for {TAR_FILE_NAME} in {self.tinygrad_sizes_path.name}")

  def _tinygrad_update_cancelled(self, tinygrad_tar_path):
    """If a cancel was requested, delete the partial archive, report it, clear the cancel flag, and return True."""
    if not params_memory.get_bool(CANCEL_DOWNLOAD_PARAM):
      return False
    delete_file(tinygrad_tar_path)
    handle_error(None, "Tinygrad update cancelled...", "Tinygrad update cancelled...", UPDATE_TINYGRAD_PARAM, DOWNLOAD_PROGRESS_PARAM)
    params_memory.remove(CANCEL_DOWNLOAD_PARAM)
    return True

  def update_tinygrad(self):
    """Download, verify, and install the tinygrad bundle, then refresh the tinygrad models.

    Any unexpected exception is reported through handle_error() rather than raised.
    """
    repo_url = get_repository_url(self.session)
    if not repo_url:
      handle_error(None, "GitHub and GitLab are offline...", "Repository unavailable", None, None)
      return

    primary_url = f"{repo_url}/Tinygrad/{TAR_FILE_NAME}"
    fallback_url = f"https://gitlab.com/{RESOURCES_REPO}/-/raw/Tinygrad/{TAR_FILE_NAME}"
    tinygrad_tar_path = Path("/data/tmp/tinygrad.tar.gz")

    try:
      print(f"Attempting to download tinygrad from {primary_url}...")
      download_file(CANCEL_DOWNLOAD_PARAM, tinygrad_tar_path, DOWNLOAD_PROGRESS_PARAM, primary_url, UPDATE_TINYGRAD_PARAM, self.session)
      if self._tinygrad_update_cancelled(tinygrad_tar_path):
        return

      if not verify_download(tinygrad_tar_path, primary_url, self.session):
        print(f"Verification failed for {primary_url}. Retrying from GitLab...")
        download_file(CANCEL_DOWNLOAD_PARAM, tinygrad_tar_path, DOWNLOAD_PROGRESS_PARAM, fallback_url, UPDATE_TINYGRAD_PARAM, self.session)
        if self._tinygrad_update_cancelled(tinygrad_tar_path):
          return
        if not verify_download(tinygrad_tar_path, fallback_url, self.session):
          handle_error(tinygrad_tar_path, "Verification Failed", "Tinygrad verification failed", UPDATE_TINYGRAD_PARAM, DOWNLOAD_PROGRESS_PARAM)
          return

      print("Tinygrad downloaded successfully! Proceeding with installation...")
      self.update_tinygrad_size(tinygrad_tar_path)
      params_memory.put(DOWNLOAD_PROGRESS_PARAM, "Installing...")

      # NOTE(review): the old install is removed before extraction, so a failed extract_tar
      # leaves no tinygrad installed.
      print("Deleting old tinygrad directories...")
      delete_file(TINYGRAD_MODELD_PATH)
      print(f"Removed {TINYGRAD_MODELD_PATH}")
      delete_file(TINYGRAD_REPO_PATH)
      print(f"Removed {TINYGRAD_REPO_PATH}")

      extract_tar(tinygrad_tar_path, Path(BASEDIR))

      print("Tinygrad update completed successfully!")
      params.put_bool("TinygradUpdateAvailable", False)
      params_memory.put(DOWNLOAD_PROGRESS_PARAM, "Updated!")
      params_memory.remove(UPDATE_TINYGRAD_PARAM)

      self.update_tinygrad_models(repo_url)
    except Exception as exception:
      handle_error(tinygrad_tar_path, "Update Failed", f"An unexpected error occurred: {exception}", UPDATE_TINYGRAD_PARAM, DOWNLOAD_PROGRESS_PARAM)

  def update_tinygrad_models(self, repo_url=None):
    """Delete every tinygrad-format model file, restore the bundled defaults, and optionally redownload.

    :param repo_url: when given, the model list is refreshed and each previously installed
        tinygrad model is downloaded again, the currently selected model first. The default
        model and models no longer listed are skipped.
    """
    print("Updating old Tinygrad models...")

    available = set(self.available_models)
    installed_tinygrad_models = set()
    for filename, _ in TINYGRAD_FILES:
      suffix = f"_{filename}"
      for file_path in MODELS_PATH.glob(f"*{suffix}"):
        model_name = file_path.name.rsplit(suffix, 1)[0]
        if model_name in available:
          installed_tinygrad_models.add(model_name)
        delete_file(file_path)

    self.copy_default_model()
    update_frogpilot_toggles()

    if repo_url is None:
      return

    current_model = (params.get("Model", encoding="utf-8") or "").removesuffix("_default")
    candidates = [current_model] + [model for model in sorted(installed_tinygrad_models) if model != current_model]
    models_to_redownload = [model for model in candidates if model and model != DEFAULT_MODEL]

    if models_to_redownload:
      print(f"Redownloading the following models: {', '.join(models_to_redownload)}")
      self.fetch_models(f"{repo_url}/Versions/model_names_{VERSION}.json", repo_url, boot_run=True)

      for model in models_to_redownload:
        if params_memory.get_bool(CANCEL_DOWNLOAD_PARAM):
          handle_error(None, "Download cancelled...", "Download cancelled...", MODEL_DOWNLOAD_ALL_PARAM, DOWNLOAD_PROGRESS_PARAM)
          return

        if model not in self.available_models:
          print(f"Skipping {model}: no longer in the available models list")
          continue
        if self.is_tinygrad_model(model):
          # .thneed models were not deleted above, so there is nothing to restore.
          continue

        params_memory.put(DOWNLOAD_PROGRESS_PARAM, f"Downloading \"{self._display_name(model)}\"...")
        self.download_model(model)
    else:
      print("No previously installed tinygrad models to redownload")

    update_frogpilot_toggles()

  def validate_models(self):
    """Reset inconsistent model state: stale "_default" selections, tinygrad version bumps, and out-of-sync model lists."""
    current = params.get("Model", encoding="utf-8") or ""
    default = params_default.get("Model", encoding="utf-8")
    if current.endswith("_default") and current != default:
      print(f"Model '{current}' does not match default '{default}', resetting...")
      params.put("Model", default)

    if VERSION_PATH.is_file():
      version_name = VERSION_PATH.read_text().strip()
      if version_name != VERSION or int(self.tinygrad_sizes.get(TAR_FILE_NAME, 0)) == 0:
        self.update_tinygrad_models()

        self.tinygrad_sizes[TAR_FILE_NAME] = DEFAULT_TINYGRAD_SIZE
        update_json_file(self.tinygrad_sizes_path, self.tinygrad_sizes)
        print(f"Updated size for {TAR_FILE_NAME} in {self.tinygrad_sizes_path.name}")

        params.remove("TinygradUpdateAvailable")

        VERSION_PATH.write_text(VERSION)
        print(f"Updated {VERSION_PATH} to {VERSION}")

    if len(self.available_models) != len(self.available_model_names) or len(self.available_models) != len(self.model_versions):
      print("Model lists are out of sync. Resetting parameters...")
      # Keep the in-memory state as lists (callers index and search them); the params
      # hold the same single-entry values as before.
      self.available_models = [DEFAULT_MODEL]
      self.available_model_names = [DEFAULT_MODEL_NAME]
      self.model_versions = [DEFAULT_MODEL_VERSION]
      params.put("AvailableModels", DEFAULT_MODEL)
      params.put("AvailableModelNames", DEFAULT_MODEL_NAME)
      params.put("ModelVersions", DEFAULT_MODEL_VERSION)