Custom Weights
CustomWeights
Bases: PlainResNetInference
Custom-weights classifier that inherits from PlainResNetInference. It can load any model produced with the PytorchWildlife fine-tuning tool.
Source code in PytorchWildlife/models/classification/resnet_base/custom_weights.py
__init__(weights=None, class_names=None, device='cpu')
Initialize the CustomWeights Classifier.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
weights | str | Path to the model weights. | None |
class_names | list[str] | List of class names for the classifier. | None |
device | str | Device used for model inference. | 'cpu' |
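The constructor takes the class names as a plain list. As a minimal sketch, assuming (this page does not confirm it) that position `i` in `class_names` corresponds to output column `i` of the model, the hypothetical helper `build_class_map` below shows how such a list could be turned into an index-to-label mapping plus the class count a classification head would need. It is not part of PytorchWildlife, and the labels are illustrative only.

```python
from typing import Dict, List, Optional, Tuple


def build_class_map(class_names: Optional[List[str]]) -> Tuple[Dict[int, str], int]:
    """Hypothetical helper: map output index -> class name and return the class count.

    Assumption: the order of `class_names` matches the order of the model's logits.
    """
    if not class_names:
        # The documented default is None; this sketch treats None or [] as an error.
        raise ValueError("class_names must be a non-empty list of strings")
    class_map = {i: name for i, name in enumerate(class_names)}
    return class_map, len(class_map)


# Illustrative labels only.
class_map, num_classes = build_class_map(["deer", "fox", "empty"])
print(class_map)    # {0: 'deer', 1: 'fox', 2: 'empty'}
print(num_classes)  # 3
```

Keeping the mapping keyed by integer index makes it straightforward to translate an argmax over the logits back into a readable label.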
results_generation(logits, img_ids, id_strip=None)
Convert model logits into per-image classification results.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
logits | Tensor | Output tensor from the model. | required |
img_ids | list[str] | List of image identifiers. | required |
id_strip | str | String to strip from image IDs before they are saved in the results. | None |
Returns:
Type | Description |
---|---|
list[dict] | List of dictionaries, one per image, containing the image ID, prediction, and confidence score. |
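The shape of the returned list can be illustrated without PyTorch. The sketch below is a library-free approximation, not the library's implementation. It assumes that `logits` holds one row of raw scores per image, that the confidence is the softmax probability of the top-scoring class, and that `id_strip` is applied with Python's `str.strip` (which removes any of the given characters from both ends, not a whole substring). The function name `results_generation_sketch`, its extra `class_names` argument, and the dictionary keys `img_id`, `prediction`, and `confidence` are hypothetical choices made to match the description above.

```python
import math
from typing import Dict, List, Optional, Sequence


def softmax(row: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one row of logits."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def results_generation_sketch(
    logits: Sequence[Sequence[float]],
    img_ids: Sequence[str],
    class_names: Sequence[str],
    id_strip: Optional[str] = None,
) -> List[Dict[str, object]]:
    """Hypothetical stand-in for results_generation: one result dict per image."""
    if len(logits) != len(img_ids):
        raise ValueError("need exactly one logits row per image ID")
    results = []
    for row, img_id in zip(logits, img_ids):
        probs = softmax(row)
        # Index of the highest-probability class.
        pred = max(range(len(probs)), key=probs.__getitem__)
        results.append({
            # Assumption: str.strip semantics. None strips whitespace;
            # a string strips any of its characters from both ends.
            "img_id": str(img_id).strip(id_strip),
            "prediction": class_names[pred],
            "confidence": probs[pred],
        })
    return results


out = results_generation_sketch(
    logits=[[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]],
    img_ids=["./cam1/img_001.jpg", "./cam1/img_002.jpg"],
    class_names=["deer", "fox", "empty"],
    id_strip="./",
)
for r in out:
    print(r["img_id"], r["prediction"], round(r["confidence"], 3))
```

Because `str.strip` works on a character set, a value such as `"./"` removes every leading or trailing `.` and `/`, so IDs that end in those characters would also be trimmed at the end.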