赞
踩
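对于给定的标量时间概念 τ，Time2Vec 的表示 t2v(τ) 是一个大小为 k+1 的向量，定义如下：

$$
\mathrm{t2v}(\tau)[i] =
\begin{cases}
\omega_i \tau + \varphi_i, & i = 0 \\
\mathcal{F}(\omega_i \tau + \varphi_i), & 1 \le i \le k
\end{cases}
$$

其中 $\omega_i$、$\varphi_i$ 为可学习参数，$\mathcal{F}$ 为周期激活函数（例如 $\sin$）。注意：下面的代码把线性项（$i=0$）放在输出向量的最后一个分量，而不是第一个。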
def t2v(tau, f, out_features, w, b, w0, b0, arg=None):
    """Compute the Time2Vec embedding of ``tau``.

    The result concatenates the periodic features ``f(tau @ w + b)`` with the
    linear (non-periodic) features ``tau @ w0 + b0``, in that order.

    Args:
        tau: Input tensor; its last dimension must match the first dimension
            of ``w`` and ``w0`` (``torch.matmul`` semantics, so extra leading
            batch dimensions are allowed).
        f: Periodic activation applied to the first part, e.g. ``torch.sin``.
        out_features: Not used by this function; kept for interface
            compatibility with existing callers.
        w, b: Weight and bias of the periodic part.
        w0, b0: Weight and bias of the linear part.
        arg: Optional extra positional argument forwarded to ``f``. ``None``
            (the default) means ``f`` is called with a single argument.

    Returns:
        Tensor whose last dimension holds the periodic features followed by
        the linear features.
    """
    # Compare with None explicitly so falsy-but-valid values such as 0 are
    # still forwarded to ``f`` instead of being silently dropped.
    if arg is not None:
        v1 = f(torch.matmul(tau, w) + b, arg)
    else:
        v1 = f(torch.matmul(tau, w) + b)
    v2 = torch.matmul(tau, w0) + b0
    # Join along the feature (last) axis: identical to dim=1 for 2-D inputs,
    # and also correct for batched inputs with extra leading dimensions.
    return torch.cat([v1, v2], -1)
class SineActivation(nn.Module):
    """Time2Vec layer using ``torch.sin`` as the periodic activation.

    Maps inputs whose last dimension is ``in_features`` to outputs whose last
    dimension is ``out_features``: ``out_features - 1`` sine features followed
    by one linear feature (computed by ``t2v``).
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.out_features = out_features
        # Linear (non-periodic) part: a single output feature.
        self.w0 = nn.parameter.Parameter(torch.randn(in_features, 1))
        # Biases are per *output* feature. A (1, n) shape broadcasts over the
        # (batch, n) matmul result for any in_features; an (in_features, n)
        # shape would only broadcast when in_features == 1. For
        # in_features == 1 the shapes (and randn draw order) are unchanged.
        self.b0 = nn.parameter.Parameter(torch.randn(1, 1))
        # Periodic part: out_features - 1 output features.
        self.w = nn.parameter.Parameter(torch.randn(in_features, out_features - 1))
        self.b = nn.parameter.Parameter(torch.randn(1, out_features - 1))
        self.f = torch.sin

    def forward(self, tau):
        """Return the Time2Vec embedding of ``tau`` via ``t2v``."""
        return t2v(tau, self.f, self.out_features, self.w, self.b, self.w0, self.b0)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。