Skip to content

Commit c509dc9

Browse files
author
Глеб Брыкин
authored
Add files via upload
1 parent a10fdd0 commit c509dc9

File tree

2 files changed

+222
-0
lines changed

2 files changed

+222
-0
lines changed
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
using System;
2+
using System.IO;
3+
using System.Net;
4+
using System.Collections.Generic;
5+
6+
namespace System
7+
{
8+
9+
namespace AI
10+
{
11+
12+
public static partial class torchvision
13+
{
14+
15+
public static partial class models
16+
{
17+
18+
static models()
19+
{
20+
ServicePointManager.SecurityProtocol = SecurityProtocolType.Tls12;
21+
}
22+
23+
internal static Dictionary<string, torch.Tensor> __load_model(string path, string fname, string url, bool print = true)
24+
{
25+
if(File.Exists(path))
26+
{
27+
return (Dictionary<string, torch.Tensor>)torch.load(path);
28+
}
29+
else
30+
{
31+
if(print)
32+
{
33+
Console.WriteLine("Downloading " + fname + " from " + url);
34+
}
35+
(new WebClient()).DownloadFile(url, path + fname);
36+
return (Dictionary<string, torch.Tensor>)torch.load(path + fname);
37+
}
38+
}
39+
40+
}
41+
42+
}
43+
44+
}
45+
46+
}
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
using nn = System.AI.torch.nn;
2+
3+
namespace System
4+
{
5+
6+
namespace AI
7+
{
8+
9+
public static partial class torchvision
10+
{
11+
12+
public static partial class models
13+
{
14+
15+
public sealed class Fire : nn.Module
16+
{
17+
18+
public int inplanes;
19+
20+
public nn.Conv2d squeeze;
21+
22+
public nn.ReLU squeeze_activation;
23+
24+
public nn.Conv2d expand1x1;
25+
26+
public nn.ReLU expand1x1_activation;
27+
28+
public nn.Conv2d expand3x3;
29+
30+
public nn.ReLU expand3x3_activation;
31+
32+
public Fire(int inplanes,
33+
int squeeze_planes,
34+
int expand1x1_planes,
35+
int expand3x3_planes)
36+
{
37+
this.inplanes = inplanes;
38+
this.squeeze = new nn.Conv2d(inplanes,
39+
squeeze_planes,
40+
kernel_size: 1);
41+
this.squeeze_activation = new nn.ReLU();
42+
this.expand1x1 = new nn.Conv2d(squeeze_planes,
43+
expand1x1_planes,
44+
kernel_size: 1);
45+
this.expand1x1_activation = new nn.ReLU();
46+
this.expand3x3 = new nn.Conv2d(squeeze_planes,
47+
expand3x3_planes,
48+
kernel_size: 3,
49+
padding: 1);
50+
this.expand3x3_activation = new nn.ReLU();
51+
}
52+
53+
public override torch.Tensor forward(torch.Tensor x)
54+
{
55+
x = this.squeeze_activation.forward(this.squeeze.forward(x));
56+
return torch.cat(new torch.Tensor[]{
57+
this.expand1x1_activation.forward(this.expand1x1.forward(x)),
58+
this.expand3x3_activation.forward(this.expand3x3.forward(x))
59+
}, 1);
60+
}
61+
62+
}
63+
64+
public sealed class SqueezeNet : nn.Module
65+
{
66+
67+
public nn.Sequential features;
68+
69+
public nn.Sequential classifier;
70+
71+
public int num_classes;
72+
73+
public SqueezeNet(string version = "1_0",
74+
int num_classes = 1000)
75+
{
76+
this.num_classes = num_classes;
77+
if(version == "1_0")
78+
{
79+
this.features = new nn.Sequential(
80+
new nn.Conv2d(3, 96, kernel_size: 7, stride: 2),
81+
new nn.ReLU(),
82+
new nn.MaxPool2d(kernel_size: 3, stride: 2),
83+
new Fire(96, 16, 64, 64),
84+
new Fire(128, 16, 64, 64),
85+
new Fire(128, 32, 128, 128),
86+
new nn.MaxPool2d(kernel_size: 3, stride: 2),
87+
new Fire(256, 32, 128, 128),
88+
new Fire(256, 48, 192, 192),
89+
new Fire(384, 48, 192, 192),
90+
new Fire(384, 64, 256, 256),
91+
new nn.MaxPool2d(kernel_size: 3, stride: 2),
92+
new Fire(512, 64, 256, 256)
93+
);
94+
}
95+
else
96+
{
97+
if(version == "1_1")
98+
{
99+
this.features = new nn.Sequential(
100+
new nn.Conv2d(3, 64, kernel_size: 3, stride: 2),
101+
new nn.ReLU(),
102+
new nn.MaxPool2d(kernel_size: 3, stride: 2),
103+
new Fire(64, 16, 64, 64),
104+
new Fire(128, 16, 64, 64),
105+
new nn.MaxPool2d(kernel_size: 3, stride: 2),
106+
new Fire(128, 32, 128, 128),
107+
new Fire(256, 32, 128, 128),
108+
new nn.MaxPool2d(kernel_size: 3, stride: 2),
109+
new Fire(256, 48, 192, 192),
110+
new Fire(384, 48, 192, 192),
111+
new Fire(384, 64, 256, 256),
112+
new Fire(512, 64, 256, 256)
113+
);
114+
}
115+
else
116+
{
117+
throw new torch.TorchException(string.Format("TorchException: Unsupported SqueezeNet version {0}: 1_0 or 1_1 expected.", version));
118+
}
119+
}
120+
var final_conv = new nn.Conv2d(512, this.num_classes, kernel_size: 1);
121+
this.classifier = new nn.Sequential(
122+
new nn.Dropout(p: 0.5),
123+
final_conv,
124+
new nn.ReLU(),
125+
new nn.AvgPool2d(13, 13)
126+
);
127+
}
128+
129+
public torch.Tensor forward(torch.Tensor x)
130+
{
131+
if((x.shape[1] != 3) || (x.shape[2] != 224) || (x.shape[3] != 224))
132+
{
133+
throw new ArgumentException("Unsupported image size: should be bx3x224x224.");
134+
}
135+
x = this.features.forward(x) as torch.Tensor;
136+
x = this.classifier.forward(x) as torch.Tensor;
137+
return torch.flatten(x, 1);
138+
}
139+
140+
}
141+
142+
public static SqueezeNet squeezenet1_0(bool pretrained = false, int num_classes = 1000)
143+
{
144+
var m = new SqueezeNet("1_0");
145+
if(pretrained)
146+
{
147+
m.load_state_dict(
148+
__load_model("",
149+
"squeezenet1_0.thn",
150+
"https://github.com/ColorfulSoft/System.AI/raw/master/Implementation/src/torchvision/models/thn/squeezenet1_0.thn")
151+
);
152+
}
153+
return m;
154+
}
155+
156+
public static SqueezeNet squeezenet1_1(bool pretrained = false, int num_classes = 1000)
157+
{
158+
var m = new SqueezeNet("1_1");
159+
if(pretrained)
160+
{
161+
m.load_state_dict(
162+
__load_model("",
163+
"squeezenet1_1.thn",
164+
"https://github.com/ColorfulSoft/System.AI/raw/master/Implementation/src/torchvision/models/thn/squeezenet1_1.thn")
165+
);
166+
}
167+
return m;
168+
}
169+
170+
}
171+
172+
}
173+
174+
}
175+
176+
}

0 commit comments

Comments
 (0)