speciesnet/model_info/
mod.rs

1use std::{
2    fs::read_to_string,
3    path::{Path, PathBuf},
4};
5
6use serde::{Deserialize, Serialize};
7
8use crate::error::Error;
9
10#[cfg(feature = "download-model")]
11pub mod download_model;
12
13/// Possible types of the speciesnet model.
14#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq)]
15#[serde(rename_all = "snake_case")]
16pub enum ModelType {
17    AlwaysCrop,
18    FullImage,
19}
20
21/// Struct containing the model's information and where the files are.
22#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
23pub struct ModelInfo {
24    /// Version of the loaded model.
25    version: String,
26    /// Type of the loaded model.
27    #[serde(rename = "type")]
28    model_type: ModelType,
29    /// Path of where the classifier model is.
30    classifier: PathBuf,
31    /// Path of where the classifier labels is.
32    classifier_labels: PathBuf,
33    /// Path of where the detector model is.
34    detector: PathBuf,
35    /// Path of the taxonomy file.
36    taxonomy: PathBuf,
37    /// Path of the geofence file.
38    geofence: PathBuf,
39}
40
41impl ModelInfo {
42    /// Constructs the [`ModelInfo`] instance from a given folder of an extracted path of the
43    /// model.
44    pub fn from_path<P>(folder: P) -> Result<Self, Error>
45    where
46        P: AsRef<Path>,
47    {
48        let info_json_string = read_to_string(folder.as_ref().join("info.json"))?;
49        let info_json: Self = serde_json::from_str(&info_json_string)?;
50
51        let classifier_path = folder.as_ref().join(info_json.classifier());
52        let classifier_labels_path = folder.as_ref().join(info_json.classifier_labels());
53        let detector_path = folder.as_ref().join(info_json.detector());
54        let taxonomy_path = folder.as_ref().join(info_json.taxonomy());
55        let geofence_path = folder.as_ref().join(info_json.geofence());
56
57        Ok(Self {
58            version: info_json.version,
59            model_type: info_json.model_type,
60            classifier: classifier_path,
61            classifier_labels: classifier_labels_path,
62            detector: detector_path,
63            taxonomy: taxonomy_path,
64            geofence: geofence_path,
65        })
66    }
67
68    pub fn version(&self) -> &str {
69        &self.version
70    }
71
72    pub fn model_type(&self) -> ModelType {
73        self.model_type
74    }
75
76    pub fn classifier(&self) -> &Path {
77        &self.classifier
78    }
79
80    pub fn classifier_labels(&self) -> &Path {
81        &self.classifier_labels
82    }
83
84    pub fn detector(&self) -> &Path {
85        &self.detector
86    }
87
88    pub fn taxonomy(&self) -> &Path {
89        &self.taxonomy
90    }
91
92    pub fn geofence(&self) -> &Path {
93        &self.geofence
94    }
95}