# Source code for common.optim.ne.net.cpu.static.rnnfc
""":class:`.CPUStaticRNNFC` & its config."""

from dataclasses import dataclass

import torch
from jaxtyping import Float32
from torch import Tensor, nn
@dataclass
class CPUStaticRNNFCConfig:
    """Sizes used to build a :class:`CPUStaticRNNFC`.

    Args:
        input_size: Number of features in the input tensor.
        hidden_size: Number of features in the RNN hidden state.
        output_size: Number of features in the output tensor.
    """

    input_size: int
    hidden_size: int
    output_size: int
class CPUStaticRNNFC(nn.Module):
    """CPU-running static architecture RNN w/ a final FC layer.

    All parameters are frozen (``requires_grad=False``) and set to zero at
    construction time, so the module starts out as an all-zero network whose
    weights are expected to be filled in by the caller.

    Args:
        config: Provides ``input_size``, ``hidden_size`` and ``output_size``
            (see :class:`CPUStaticRNNFCConfig`).

    Attributes:
        rnn: ``nn.RNNCell`` mapping ``input_size`` -> ``hidden_size``.
        fc: ``nn.Linear`` mapping ``hidden_size`` -> ``output_size``.
        h: Hidden state of the RNN, a 1-D tensor of length ``hidden_size``
            initialised to zeros.
    """

    def __init__(
        self: "CPUStaticRNNFC",
        config: "CPUStaticRNNFCConfig",
    ) -> None:
        super().__init__()
        self.rnn = nn.RNNCell(
            input_size=config.input_size,
            hidden_size=config.hidden_size,
        )
        self.fc = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.output_size,
        )
        self.h: Float32[Tensor, " hidden_size"] = torch.zeros(
            size=(config.hidden_size,),
        )
        for param in self.parameters():
            # Freeze first, then zero in place: no gradient is ever tracked
            # and the parameter keeps its own storage (no `.data` rebinding).
            param.requires_grad_(False)
            nn.init.zeros_(param)

    def reset(self: "CPUStaticRNNFC") -> None:
        """Resets the hidden state of the RNN.

        The state is zeroed in place, so other references to the same tensor
        observe the reset. Zeroing (rather than multiplying by zero) also
        clears ``nan`` / ``inf`` entries, since ``nan * 0`` and ``inf * 0``
        are both ``nan``.
        """
        self.h.zero_()