sgdrf.model
Implementation of SGDRF model.
Module Contents
- class sgdrf.model.SGDRF(config: sgdrf.sgdrf_config.SGDRFConfig)
Streaming Gaussian Dirichlet Random Field model.
- Parameters:
config (SGDRFConfig) – SGDRF config object
- xu
Sparse inducing points
- Type:
torch.Tensor
- dims
The number of spatiotemporal dimensions for this model
- Type:
int
- V
The number of observation types
- Type:
int
- K
The number of latent Gaussian processes
- Type:
int
- M
The number of inducing points
- Type:
int
- latent_shape
The Pyro shape of the latent Gaussian processes
- Type:
tuple[int]
- max_obs
The maximum number of possible simultaneous categorical observations
- Type:
int
- device
PyTorch device (e.g. torch.device('cuda'))
- Type:
torch.device
- dir_p
The Dirichlet hyperparameters for each entry in the word-topic matrix
- Type:
torch.Tensor
- jitter
Small jitter to add to covariance matrix diagonal
- Type:
float
- zero_loc
Tensor of all-zeros matching the inducing points
- Type:
torch.Tensor
- uloc
Inducing point mean variational parameter
- Type:
pyro.nn.module.PyroParam
- uscaletril
Variational parameter for the lower-triangular Cholesky factor of the inducing-point covariance
- Type:
pyro.nn.module.PyroParam
- word_topic_probs
Variational parameter holding the maximum a posteriori (MAP) estimate of the word-topic matrix
- Type:
pyro.nn.module.PyroParam
- kernel
Gaussian process kernel
- Type:
pyro.contrib.gp.kernels.Kernel
- whiten
Whether the Gaussian process covariance matrix is whitened
- Type:
bool
- subsampler
Subsampling algorithm
- Type:
- num_particles
Number of parallel posterior latent samples to draw
- Type:
int
- objective
The objective function used during training
- Type:
pyro.infer.ELBO
- xs
All the locations of past observations
- Type:
torch.Tensor
- ws
All the past observations
- Type:
torch.Tensor
- optimizer
Stochastic gradient descent optimizer
- Type:
pyro.optim.PyroOptim
- svi
Stochastic variational inference helper object
- Type:
pyro.infer.SVI
- fail_on_nan_loss
Whether to raise an exception if training loss is NaN
- Type:
bool
- n_xs
The number of past observation locations
- Type:
int
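The jitter, uscaletril, and whiten attributes rely on standard Gaussian process numerics. The stdlib-only sketch below is not the SGDRF implementation (which works on torch tensors through Pyro); it only illustrates two ideas. First, a Cholesky factorization with a small jitter added to the covariance diagonal, which yields the kind of lower-triangular factor that uscaletril parametrizes. Second, the usual whitening convention (as in Pyro's sparse GP models), where a whitened variable v ~ N(0, I) is mapped to u = L v. The names cholesky and unwhiten are hypothetical.

```python
import math
from typing import List

Matrix = List[List[float]]


def cholesky(a: Matrix, jitter: float = 1e-6) -> Matrix:
    """Return lower-triangular L with L @ L^T == a + jitter * I (Cholesky-Banachiewicz)."""
    n = len(a)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                # Jitter is added only on the diagonal to keep the matrix positive definite.
                L[i][j] = math.sqrt(a[i][i] + jitter - s)
            else:
                L[i][j] = (a[i][j] - s) / L[j][j]
    return L


def unwhiten(L: Matrix, v: List[float]) -> List[float]:
    """Map whitened v ~ N(0, I) to u = L v, so that u ~ N(0, L L^T)."""
    return [sum(L[i][k] * v[k] for k in range(i + 1)) for i in range(len(v))]
```

With whitening, the variational parameters describe v rather than u, which typically makes optimization better conditioned because the prior on v is a fixed standard normal.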
- static check_inputs(xs: torch.Tensor | None = None, ws: torch.Tensor | None = None)
Check that the input locations and observations agree in shape.
- Parameters:
xs (Optional[torch.Tensor], optional) – The input locations, by default None
ws (Optional[torch.Tensor], optional) – The input observations, by default None
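A minimal sketch of the kind of check check_inputs performs, working on shape tuples instead of tensors. It assumes xs has shape (N, dims) and ws has shape (N, V), so the first dimensions must match; the 2-D assumption, the choice of ValueError, and the name check_input_shapes are all hypothetical and not taken from the SGDRF source.

```python
from typing import Optional, Sequence


def check_input_shapes(
    xs_shape: Optional[Sequence[int]] = None,
    ws_shape: Optional[Sequence[int]] = None,
) -> None:
    """Raise ValueError if locations and observations disagree in size.

    Hypothetical: assumes xs is (N, dims) and ws is (N, V).
    """
    if xs_shape is None or ws_shape is None:
        # Nothing to compare when either input is missing.
        return
    if len(xs_shape) != 2 or len(ws_shape) != 2:
        raise ValueError("expected 2-D xs and ws")
    if xs_shape[0] != ws_shape[0]:
        raise ValueError(f"xs has {xs_shape[0]} rows but ws has {ws_shape[0]}")
```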
- forward(xs: torch.Tensor) → torch.Tensor
Produce word probabilities at a given set of locations.
- Parameters:
xs (torch.Tensor) – Locations to generate observation probabilities at
- Returns:
Observation probabilities
- Return type:
torch.Tensor
- guide(xs: torch.Tensor, ws: torch.Tensor, subsample: torch.Tensor)
Run the approximate posterior (the variational guide) used for stochastic variational inference.
- Parameters:
xs (torch.Tensor) – Locations of all observations
ws (torch.Tensor) – Observed categorical data
subsample (torch.Tensor) – Indices of past observations to use in this training step
- model(xs: torch.Tensor, ws: torch.Tensor, subsample: torch.Tensor) → torch.Tensor
Run the generative model (prior and likelihood) used for stochastic variational inference.
- Parameters:
xs (torch.Tensor) – Locations of all observations
ws (torch.Tensor) – Observed categorical data
subsample (torch.Tensor) – Indices of past observations to use in this training step
- Returns:
The observations
- Return type:
torch.Tensor
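Both model and guide receive a subsample argument holding indices of past observations, chosen by the subsampler attribute. This reference does not say which subsampling algorithm SGDRF uses, so the sketch below substitutes plain uniform sampling without replacement as a hypothetical stand-in; the function name uniform_subsample is also hypothetical.

```python
import random
from typing import List


def uniform_subsample(n_xs: int, batch_size: int, rng: random.Random) -> List[int]:
    """Draw up to batch_size distinct indices in [0, n_xs), uniformly at random.

    Hypothetical stand-in for the subsampler; returns sorted indices.
    """
    return sorted(rng.sample(range(n_xs), min(batch_size, n_xs)))
```

Passing an explicit random.Random instance keeps the draw reproducible. A streaming model might instead favour recent observations; that is a design choice this sketch does not attempt.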
- process_inputs(xs: torch.Tensor | None = None, ws: torch.Tensor | None = None)
Ingest new observations.
- Parameters:
xs (Optional[torch.Tensor], optional) – New locations, by default None
ws (Optional[torch.Tensor], optional) – New observations, by default None
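The attributes xs, ws, and n_xs record all past observations, and process_inputs ingests new ones. The sketch below shows that bookkeeping with Python lists in place of tensors. The class ObservationHistory is hypothetical, and ignoring a call when either argument is None is an assumption; the real method's handling of partial input is not documented here.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ObservationHistory:
    """Hypothetical stand-in for the xs / ws / n_xs bookkeeping."""

    xs: List[List[float]] = field(default_factory=list)  # past locations
    ws: List[List[int]] = field(default_factory=list)  # past observations

    @property
    def n_xs(self) -> int:
        """Number of past observation locations."""
        return len(self.xs)

    def process_inputs(
        self,
        xs: Optional[List[List[float]]] = None,
        ws: Optional[List[List[int]]] = None,
    ) -> None:
        # Assumption: nothing is ingested unless both locations and observations are given.
        if xs is None or ws is None:
            return
        if len(xs) != len(ws):
            raise ValueError("xs and ws must have the same number of rows")
        self.xs.extend(xs)
        self.ws.extend(ws)
```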
- step() → float
Take a single training step.
- Returns:
Training loss for this step
- Return type:
float
- Raises:
ValueError – If the training loss is NaN and self.fail_on_nan_loss is True
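The NaN handling documented for step can be sketched independently of Pyro. Here svi_step is a hypothetical callable that returns the loss of one optimization step (in SGDRF this role is played by the svi object); only the raise-on-NaN logic mirrors the documented behaviour.

```python
import math
from typing import Callable


def checked_step(svi_step: Callable[[], float], fail_on_nan_loss: bool) -> float:
    """Run one hypothetical training step and apply the documented NaN policy."""
    loss = float(svi_step())
    if fail_on_nan_loss and math.isnan(loss):
        raise ValueError("training loss is NaN")
    return loss
```

With the documented API, a streaming loop would alternate process_inputs (to ingest each new batch) with one or more calls to step.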
- topic_prob(xs: torch.Tensor | None = None) → torch.Tensor
Infer or predict topic probabilities.
- Parameters:
xs (Optional[torch.Tensor], optional) – The points to predict at, by default self.xs
- Returns:
Topic probabilities at each point
- Return type:
torch.Tensor
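One common way to turn K latent Gaussian process values into topic probabilities is a softmax at each location. The sketch below assumes that link; this reference does not state it, and the real topic_prob may also average over num_particles posterior samples, whereas the sketch takes a single set of latent values. The names softmax and topic_probs are hypothetical.

```python
import math
from typing import List


def softmax(values: List[float]) -> List[float]:
    m = max(values)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def topic_probs(latent: List[List[float]]) -> List[List[float]]:
    """latent[n][k] is the value of latent GP k at location n; returns N x K probabilities."""
    return [softmax(row) for row in latent]
```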
- word_prob(xs: torch.Tensor | None = None) → torch.Tensor
Infer or predict word probabilities.
- Parameters:
xs (Optional[torch.Tensor], optional) – The points to predict at, by default self.xs
- Returns:
Word probabilities at each point
- Return type:
torch.Tensor
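In a topic model, word probabilities at a location are a mixture of per-topic word distributions weighted by the topic probabilities there: p(w | x) = sum over k of p(k | x) p(w | k). The sketch assumes the topic probabilities are N x K and the word-topic matrix is K x V with rows summing to 1; that orientation and the name word_probs are assumptions, not taken from the SGDRF source.

```python
from typing import List

Matrix = List[List[float]]


def word_probs(topic_probs: Matrix, word_topic: Matrix) -> Matrix:
    """p(w | x_n) = sum_k p(k | x_n) * p(w | k).

    topic_probs is N x K, word_topic is K x V; the result is N x V.
    """
    n_words = len(word_topic[0])
    return [
        [sum(t[k] * word_topic[k][w] for k in range(len(t))) for w in range(n_words)]
        for t in topic_probs
    ]
```

Because each row of both inputs sums to 1, each row of the result is also a probability distribution over the V observation types.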
- word_topic_prob() → pyro.nn.module.PyroParam
Get the inferred word-topic probability matrix.
- Returns:
The inferred word-topic probability matrix (MAP estimate)
- Return type:
pyro.nn.module.PyroParam
- sgdrf.model.EPSILON = 0.01
- sgdrf.model.TORCH_DEVICE_CPU