Changes in spotpython-0.16.9:
- xai.py: add new function viz_net to visualize the network architecture (linear nets).
- dimensions.py: add new function extract_linear_dims that extracts the input and output dimensions of the Linear layers in a PyTorch model.
The following example demonstrates the usage of the viz_net() function:
from spotpython.plot.xai import viz_net
from spotpython.utils.init import fun_control_init
from spotpython.data.diabetes import Diabetes
from spotpython.light.regression.nn_linear_regressor import NNLinearRegressor
from spotpython.hyperdict.light_hyper_dict import LightHyperDict
from spotpython.hyperparameters.values import (
    get_default_hyperparameters_as_array,
    get_one_config_from_X,
)
_L_in = 10
_L_out = 1
_torchmetric = "mean_squared_error"
fun_control = fun_control_init(
    _L_in=_L_in,
    _L_out=_L_out,
    _torchmetric=_torchmetric,
    data_set=Diabetes(),
    core_model=NNLinearRegressor,
    hyperdict=LightHyperDict,
)
X = get_default_hyperparameters_as_array(fun_control)
config = get_one_config_from_X(X, fun_control)
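# Instantiate the core model with the default hyperparameter configuration.
model = fun_control["core_model"](
    **config, _L_in=_L_in, _L_out=_L_out, _torchmetric=_torchmetric
)
# Visualize the network architecture.
viz_net(
    net=model,
    device="cpu",
    show_attrs=True,
    show_saved=True,
    filename="model_architecture",
    format="png",
)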
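
The idea behind extract_linear_dims (collecting the input and output dimensions of the Linear layers in a PyTorch model) can be illustrated with a minimal sketch. The helper linear_layer_dims below is hypothetical and is not spotpython's implementation; the actual signature and return type of extract_linear_dims may differ. The sketch only assumes the standard torch.nn API:

```python
import torch.nn as nn


def linear_layer_dims(model: nn.Module) -> list[tuple[int, int]]:
    """Hypothetical helper, not spotpython's extract_linear_dims.

    Returns one (in_features, out_features) pair per nn.Linear layer,
    in the order in which model.modules() yields the layers.
    """
    return [
        (m.in_features, m.out_features)
        for m in model.modules()
        if isinstance(m, nn.Linear)
    ]


# A small stand-in network; the layer sizes are chosen for illustration only.
net = nn.Sequential(
    nn.Linear(10, 8),
    nn.ReLU(),
    nn.Linear(8, 4),
    nn.ReLU(),
    nn.Linear(4, 1),
)
dims = linear_layer_dims(net)
print(dims)  # [(10, 8), (8, 4), (4, 1)]
```

Non-linear layers such as ReLU are skipped because they carry no in_features or out_features attributes, which is why a visualization of "linear nets" only needs the Linear layers.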