Commit b943b978 authored by Rafael Dätwyler's avatar Rafael Dätwyler
Browse files

remove actions as inputs in the RNN

parent 5dcd990e
......@@ -46,10 +46,9 @@ def gmm_loss(batch, mus, sigmas, logpi, reduce=True): # pylint: disable=too-many
return - log_prob
class _MDRNNBase(nn.Module):
def __init__(self, latents, actions, hiddens, gaussians):
def __init__(self, latents, hiddens, gaussians):
super().__init__()
self.latents = latents
self.actions = actions
self.hiddens = hiddens
self.gaussians = gaussians
......@@ -61,14 +60,13 @@ class _MDRNNBase(nn.Module):
class MDRNN(_MDRNNBase):
""" MDRNN model for multi steps forward """
def __init__(self, latents, hiddens, gaussians):
    """Multi-step MDN-RNN: an LSTM unrolled over a latent sequence.

    The recurrent layer consumes only the latent vectors (actions were
    removed as inputs in this commit), so its input size is `latents`.

    :args latents: dimensionality of the latent vector (LSIZE)
    :args hiddens: number of hidden units of the LSTM (RSIZE)
    :args gaussians: number of Gaussian components in the output mixture
    """
    super().__init__(latents, hiddens, gaussians)
    self.rnn = nn.LSTM(latents, hiddens)
def forward(self, actions, latents): # pylint: disable=arguments-differ
def forward(self, latents): # pylint: disable=arguments-differ
""" MULTI STEPS forward.
:args actions: (SEQ_LEN, BSIZE, ASIZE) torch tensor
:args latents: (SEQ_LEN, BSIZE, LSIZE) torch tensor
:returns: mu_nlat, sig_nlat, pi_nlat, rs, ds, parameters of the GMM
......@@ -80,10 +78,9 @@ class MDRNN(_MDRNNBase):
- rs: (SEQ_LEN, BSIZE) torch tensor
- ds: (SEQ_LEN, BSIZE) torch tensor
"""
seq_len, bs = actions.size(0), actions.size(1)
seq_len, bs = latents.size(0), latents.size(1)
ins = torch.cat([actions, latents], dim=-1)
outs, _ = self.rnn(ins)
outs, _ = self.rnn(latents)
gmm_outs = self.gmm_linear(outs)
stride = self.gaussians * self.latents
......@@ -107,14 +104,13 @@ class MDRNN(_MDRNNBase):
class MDRNNCell(_MDRNNBase):
""" MDRNN model for one step forward """
def __init__(self, latents, hiddens, gaussians):
    """Single-step MDN-RNN: one LSTMCell update per call.

    Mirrors MDRNN but uses nn.LSTMCell for one-step-at-a-time rollout;
    the cell's input size is `latents` since actions are no longer fed in.

    :args latents: dimensionality of the latent vector (LSIZE)
    :args hiddens: number of hidden units of the LSTM cell (RSIZE)
    :args gaussians: number of Gaussian components in the output mixture
    """
    super().__init__(latents, hiddens, gaussians)
    self.rnn = nn.LSTMCell(latents, hiddens)
def forward(self, action, latent, hidden): # pylint: disable=arguments-differ
def forward(self, latent, hidden): # pylint: disable=arguments-differ
""" ONE STEP forward.
:args actions: (BSIZE, ASIZE) torch tensor
:args latents: (BSIZE, LSIZE) torch tensor
:args hidden: (BSIZE, RSIZE) torch tensor
......@@ -127,9 +123,8 @@ class MDRNNCell(_MDRNNBase):
- rs: (BSIZE) torch tensor
- ds: (BSIZE) torch tensor
"""
in_al = torch.cat([action, latent], dim=1)
next_hidden = self.rnn(in_al, hidden)
next_hidden = self.rnn(latent, hidden)
out_rnn = next_hidden[0]
out_full = self.gmm_linear(out_rnn)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment