speciesnet/model_info/
download_model.rs

1use std::{
2    fs::{File, create_dir_all},
3    io::{BufWriter, copy},
4};
5
6use directories::BaseDirs;
7use tracing::info;
8use zip::ZipArchive;
9
10use crate::error::Error;
11
12use super::ModelInfo;
13
/// Directory, relative to the user's cache directory, under which models are stored.
const MODEL_DIRECTORY: &str = "speciesnet-rust/models/";
/// Name of the folder inside `MODEL_DIRECTORY` that the default model is unzipped into.
const DEFAULT_MODEL_FOLDER: &str = "speciesnet-onnx-v4.0.0a";
/// File name under which the downloaded default model archive is saved inside `MODEL_DIRECTORY`.
const DEFAULT_MODEL_FILE_NAME: &str = "speciesnet-onnx-v4.0.0a.zip";
/// URL the default model archive is downloaded from.
const DEFAULT_MODEL_URL: &str =
    "https://drive.usercontent.google.com/download?id=1dAGnnJvOiNku6i2Zv82p0Rzidtsk02fy&confirm";
23
24impl ModelInfo {
25    /// Constructs the [`ModelInfo`] instance from a default model url, this function will download the
26    /// file from the given url, then unzips it and put at the `speciesnet-rust/models/` folder.
27    pub fn from_default_url() -> Result<ModelInfo, Error> {
28        let base_dir = BaseDirs::new().ok_or_else(|| Error::BaseDirInitFailed)?;
29        let cache_dir = base_dir.cache_dir();
30
31        info!("Cache directory is {}.", cache_dir.display());
32        info!(
33            "Creating the directory {} for putting the model.",
34            MODEL_DIRECTORY
35        );
36
37        // make a directory in the retrieved cache folder of the model.
38        let model_dir = cache_dir.join(MODEL_DIRECTORY);
39        create_dir_all(&model_dir)?;
40
41        info!(
42            "Checking if the model has been downloaded at {}.",
43            model_dir.join(DEFAULT_MODEL_FOLDER).display()
44        );
45
46        // check if the model folder exists, or not.
47        let possible_model_path = model_dir.join(DEFAULT_MODEL_FOLDER);
48
49        if possible_model_path.exists() {
50            return ModelInfo::from_path(possible_model_path);
51        }
52
53        info!("Downloading the model from {}", DEFAULT_MODEL_URL);
54
55        // download the model from the url.
56        let response = ureq::get(DEFAULT_MODEL_URL).call()?;
57
58        if response.status() != 200 {
59            return Err(Error::RequestFailed(response.status().as_u16()));
60        }
61
62        let (_, body) = response.into_parts();
63        let mut body_reader = body.into_reader();
64
65        // This block forces a drop of the writer.
66        {
67            let model_zip_file_write = File::create(model_dir.join(DEFAULT_MODEL_FILE_NAME))?;
68            let mut writer = BufWriter::new(model_zip_file_write);
69
70            copy(&mut body_reader, &mut writer)?;
71        }
72
73        info!(
74            "Unzipping the contents inside {} into {}",
75            model_dir.join(DEFAULT_MODEL_FILE_NAME).display(),
76            model_dir.join(DEFAULT_MODEL_FOLDER).display(),
77        );
78
79        // Unzip the file and put it in the models folder.
80        let model_zip_file_read = File::open(model_dir.join(DEFAULT_MODEL_FILE_NAME))?;
81
82        let mut zip_file = ZipArchive::new(model_zip_file_read)?;
83        let extract_dir = model_dir.join(DEFAULT_MODEL_FOLDER);
84        zip_file.extract(&extract_dir)?;
85
86        ModelInfo::from_path(extract_dir)
87    }
88}