Update angle loss to MSE

Angles pi/2 and -pi/2 should lead to a large loss, since they differ by pi on the unit circle.
Therefore we compute the absolute error along the "shorter" direction on the unit circle.
return torch.mean(torch.abs(torch.atan2(torch.sin(a - b), torch.cos(a - b))))
return torch.mean(torch.square(torch.abs(torch.atan2(torch.sin(a - b), torch.cos(a - b)))))
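
The wrapping step used in both return lines can be checked in isolation. Below is a minimal plain-Python sketch, not the PyTorch code from this change; the helper names `wrapped_angle_diff` and `angle_mse` are hypothetical and exist only for illustration. It assumes angles are given in radians.

```python
import math


def wrapped_angle_diff(a, b):
    """Signed difference a - b, wrapped into [-pi, pi] (hypothetical helper)."""
    d = a - b
    # atan2(sin(d), cos(d)) maps any difference onto the shorter arc of the unit circle.
    return math.atan2(math.sin(d), math.cos(d))


def angle_mse(preds, targets):
    """Mean of squared wrapped differences over paired angles (hypothetical helper)."""
    diffs = [wrapped_angle_diff(p, t) for p, t in zip(preds, targets)]
    return sum(d * d for d in diffs) / len(diffs)
```

With this wrapping, pi/2 and -pi/2 are a full pi apart, while 0.1 and 2*pi - 0.1 are only 0.2 apart, although their raw difference is close to 2*pi. Squaring the wrapped difference also makes the inner absolute value redundant, though it is harmless.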
