1+ using nn = System . AI . torch . nn ;
2+
3+ namespace System
4+ {
5+
6+ namespace AI
7+ {
8+
9+ public static partial class torchvision
10+ {
11+
12+ public static partial class models
13+ {
14+
15+ public sealed class Fire : nn . Module
16+ {
17+
18+ public int inplanes ;
19+
20+ public nn . Conv2d squeeze ;
21+
22+ public nn . ReLU squeeze_activation ;
23+
24+ public nn . Conv2d expand1x1 ;
25+
26+ public nn . ReLU expand1x1_activation ;
27+
28+ public nn . Conv2d expand3x3 ;
29+
30+ public nn . ReLU expand3x3_activation ;
31+
32+ public Fire ( int inplanes ,
33+ int squeeze_planes ,
34+ int expand1x1_planes ,
35+ int expand3x3_planes )
36+ {
37+ this . inplanes = inplanes ;
38+ this . squeeze = new nn . Conv2d ( inplanes ,
39+ squeeze_planes ,
40+ kernel_size : 1 ) ;
41+ this . squeeze_activation = new nn . ReLU ( ) ;
42+ this . expand1x1 = new nn . Conv2d ( squeeze_planes ,
43+ expand1x1_planes ,
44+ kernel_size : 1 ) ;
45+ this . expand1x1_activation = new nn . ReLU ( ) ;
46+ this . expand3x3 = new nn . Conv2d ( squeeze_planes ,
47+ expand3x3_planes ,
48+ kernel_size : 3 ,
49+ padding : 1 ) ;
50+ this . expand3x3_activation = new nn . ReLU ( ) ;
51+ }
52+
53+ public override torch . Tensor forward ( torch . Tensor x )
54+ {
55+ x = this . squeeze_activation . forward ( this . squeeze . forward ( x ) ) ;
56+ return torch . cat ( new torch . Tensor [ ] {
57+ this . expand1x1_activation . forward ( this . expand1x1 . forward ( x ) ) ,
58+ this . expand3x3_activation . forward ( this . expand3x3 . forward ( x ) )
59+ } , 1 ) ;
60+ }
61+
62+ }
63+
64+ public sealed class SqueezeNet : nn . Module
65+ {
66+
67+ public nn . Sequential features ;
68+
69+ public nn . Sequential classifier ;
70+
71+ public int num_classes ;
72+
73+ public SqueezeNet ( string version = "1_0" ,
74+ int num_classes = 1000 )
75+ {
76+ this . num_classes = num_classes ;
77+ if ( version == "1_0" )
78+ {
79+ this . features = new nn . Sequential (
80+ new nn . Conv2d ( 3 , 96 , kernel_size : 7 , stride : 2 ) ,
81+ new nn . ReLU ( ) ,
82+ new nn . MaxPool2d ( kernel_size : 3 , stride : 2 ) ,
83+ new Fire ( 96 , 16 , 64 , 64 ) ,
84+ new Fire ( 128 , 16 , 64 , 64 ) ,
85+ new Fire ( 128 , 32 , 128 , 128 ) ,
86+ new nn . MaxPool2d ( kernel_size : 3 , stride : 2 ) ,
87+ new Fire ( 256 , 32 , 128 , 128 ) ,
88+ new Fire ( 256 , 48 , 192 , 192 ) ,
89+ new Fire ( 384 , 48 , 192 , 192 ) ,
90+ new Fire ( 384 , 64 , 256 , 256 ) ,
91+ new nn . MaxPool2d ( kernel_size : 3 , stride : 2 ) ,
92+ new Fire ( 512 , 64 , 256 , 256 )
93+ ) ;
94+ }
95+ else
96+ {
97+ if ( version == "1_1" )
98+ {
99+ this . features = new nn . Sequential (
100+ new nn . Conv2d ( 3 , 64 , kernel_size : 3 , stride : 2 ) ,
101+ new nn . ReLU ( ) ,
102+ new nn . MaxPool2d ( kernel_size : 3 , stride : 2 ) ,
103+ new Fire ( 64 , 16 , 64 , 64 ) ,
104+ new Fire ( 128 , 16 , 64 , 64 ) ,
105+ new nn . MaxPool2d ( kernel_size : 3 , stride : 2 ) ,
106+ new Fire ( 128 , 32 , 128 , 128 ) ,
107+ new Fire ( 256 , 32 , 128 , 128 ) ,
108+ new nn . MaxPool2d ( kernel_size : 3 , stride : 2 ) ,
109+ new Fire ( 256 , 48 , 192 , 192 ) ,
110+ new Fire ( 384 , 48 , 192 , 192 ) ,
111+ new Fire ( 384 , 64 , 256 , 256 ) ,
112+ new Fire ( 512 , 64 , 256 , 256 )
113+ ) ;
114+ }
115+ else
116+ {
117+ throw new torch . TorchException ( string . Format ( "TorchException: Unsupported SqueezeNet version {0}: 1_0 or 1_1 expected." , version ) ) ;
118+ }
119+ }
120+ var final_conv = new nn . Conv2d ( 512 , this . num_classes , kernel_size : 1 ) ;
121+ this . classifier = new nn . Sequential (
122+ new nn . Dropout ( p : 0.5 ) ,
123+ final_conv ,
124+ new nn . ReLU ( ) ,
125+ new nn . AvgPool2d ( 13 , 13 )
126+ ) ;
127+ }
128+
129+ public torch . Tensor forward ( torch . Tensor x )
130+ {
131+ if ( ( x . shape [ 1 ] != 3 ) || ( x . shape [ 2 ] != 224 ) || ( x . shape [ 3 ] != 224 ) )
132+ {
133+ throw new ArgumentException ( "Unsupported image size: should be bx3x224x224." ) ;
134+ }
135+ x = this . features . forward ( x ) as torch . Tensor ;
136+ x = this . classifier . forward ( x ) as torch . Tensor ;
137+ return torch . flatten ( x , 1 ) ;
138+ }
139+
140+ }
141+
142+ public static SqueezeNet squeezenet1_0 ( bool pretrained = false , int num_classes = 1000 )
143+ {
144+ var m = new SqueezeNet ( "1_0" ) ;
145+ if ( pretrained )
146+ {
147+ m . load_state_dict (
148+ __load_model ( "" ,
149+ "squeezenet1_0.thn" ,
150+ "https://github.com/ColorfulSoft/System.AI/raw/master/Implementation/src/torchvision/models/thn/squeezenet1_0.thn" )
151+ ) ;
152+ }
153+ return m ;
154+ }
155+
156+ public static SqueezeNet squeezenet1_1 ( bool pretrained = false , int num_classes = 1000 )
157+ {
158+ var m = new SqueezeNet ( "1_1" ) ;
159+ if ( pretrained )
160+ {
161+ m . load_state_dict (
162+ __load_model ( "" ,
163+ "squeezenet1_1.thn" ,
164+ "https://github.com/ColorfulSoft/System.AI/raw/master/Implementation/src/torchvision/models/thn/squeezenet1_1.thn" )
165+ ) ;
166+ }
167+ return m ;
168+ }
169+
170+ }
171+
172+ }
173+
174+ }
175+
176+ }
0 commit comments