"""Simple test of ``torch.nn.Linear`` to determine its properties.

Feeds a random batch through a freshly initialised fully connected layer and
prints the size and contents of the resulting tensor.
"""
import torch
import torch.nn as nn


def linear_demo(in_features: int = 20, out_features: int = 30,
                batch_size: int = 128) -> torch.Tensor:
    """Run a random batch through a new ``nn.Linear`` layer and return the result.

    Args:
        in_features: Width of each input sample (the layer's input size).
        out_features: Width of each output sample (the layer's output size).
        batch_size: Number of random samples drawn with ``torch.randn``.

    Returns:
        The layer output, shaped ``(batch_size, out_features)``.
    """
    layer = nn.Linear(in_features, out_features)
    # Named ``inputs`` rather than ``input`` so the builtin is not shadowed.
    inputs = torch.randn(batch_size, in_features)
    return layer(inputs)


if __name__ == "__main__":
    output = linear_demo()
    print(output.size())  # with the defaults: torch.Size([128, 30])
    print(output)