|
206 | 206 | "execution_count": 3,
207 | 207 | "id": "dcdbe1ae-ea13-49cb-b5a3-3c2c78f91f2b",
208 | 208 | "metadata": {
| 209 | + "lines_to_next_cell": 2,
209 | 210 | "tags": []
210 | 211 | },
211 | 212 | "outputs": [
|
224 | 225 | "import torch.nn as nn\n", |
225 | 226 | "import torch.nn.functional as F\n", |
226 | 227 | "\n", |
227 | | - "\n", |
228 | 228 | "class Net(nn.Module):\n", |
229 | 229 | " def __init__(self):\n", |
230 | 230 | " super().__init__()\n", |
|
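Note: only the opening of `Net` is visible in this hunk. For orientation, a minimal CNN in the style of the standard PyTorch CIFAR-10 tutorial is sketched below; the layer sizes here are illustrative assumptions, not read from this notebook.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)        # 3 RGB channels in, 6 feature maps out, 5x5 kernel
        self.pool = nn.MaxPool2d(2, 2)         # halves spatial resolution
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 32x32 input -> 16 maps of 5x5 after two conv/pool stages
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)           # 10 CIFAR-10 classes

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)                # flatten everything but the batch dimension
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)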
257 | 257 | "cell_type": "code", |
258 | 258 | "execution_count": 4, |
259 | 259 | "id": "189d71c5-6556-4891-a382-0adbc8f80d30", |
260 | | - "metadata": {}, |
| 260 | + "metadata": { |
| 261 | + "lines_to_next_cell": 2 |
| 262 | + }, |
261 | 263 | "outputs": [ |
262 | 264 | { |
263 | 265 | "name": "stdout", |
|
274 | 276 | "\n", |
275 | 277 | "transform = transforms.Compose(\n", |
276 | 278 | " [transforms.ToTensor(),\n", |
277 | | - " transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n"
| 279 | + " transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])" |
278 | 280 | ] |
279 | 281 | }, |
280 | 282 | { |
281 | 283 | "cell_type": "code", |
282 | 284 | "execution_count": 5, |
283 | 285 | "id": "3d8f233e-495c-450c-a445-46d295ba7461", |
284 | 286 | "metadata": { |
| 287 | + "lines_to_next_cell": 2, |
285 | 288 | "tags": [] |
286 | 289 | }, |
287 | 290 | "outputs": [ |
|
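Note: the `Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` in the hunk above applies `(x - mean) / std` per channel, so pixels that `ToTensor` placed in [0, 1] land in [-1, 1]. A quick sanity check (the tensor is a made-up example, not notebook data):

import torch
from torchvision import transforms

norm = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
t = torch.zeros(3, 32, 32)                       # an all-black image after ToTensor
print(norm(t).min().item())                      # -1.0, since (0 - 0.5) / 0.5 = -1
print(norm(torch.ones(3, 32, 32)).max().item())  # 1.0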
301 | 304 | "\n", |
302 | 305 | "batch_size = 4\n", |
303 | 306 | "\n", |
| 307 | + "\n", |
304 | 308 | "def get_dataloader(is_training, transform):\n", |
305 | | - " \n", |
| 309 | + "\n", |
306 | 310 | " if is_training:\n", |
307 | 311 | " trainset = torchvision.datasets.CIFAR10(root='./data', train=True,\n", |
308 | | - " download=True, transform=transform)\n", |
| 312 | + " download=True, transform=transform)\n", |
309 | 313 | " trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,\n", |
310 | 314 | " shuffle=True, num_workers=2)\n", |
311 | 315 | " return trainloader\n", |
312 | 316 | " else:\n", |
313 | 317 | " testset = torchvision.datasets.CIFAR10(root='./data', train=False,\n", |
314 | | - " download=True, transform=transform)\n", |
| 318 | + " download=True, transform=transform)\n", |
315 | 319 | " testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,\n", |
316 | 320 | " shuffle=False, num_workers=2)\n", |
317 | 321 | " return testloader " |
|
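Note: assuming the `transform` defined earlier is in scope, the new `get_dataloader` would be called like this (a hypothetical usage sketch, not a cell from the notebook):

trainloader = get_dataloader(is_training=True, transform=transform)
testloader = get_dataloader(is_training=False, transform=transform)

images, labels = next(iter(trainloader))
print(images.shape)   # torch.Size([4, 3, 32, 32]) with batch_size = 4
print(labels.shape)   # torch.Size([4])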
347 | 351 | "import torch.nn as nn\n", |
348 | 352 | "import torch.optim as optim\n", |
349 | 353 | "\n", |
350 | | - "\n", |
351 | | - "def train(net,trainloader):\n", |
| 354 | + "def train(net, trainloader):\n", |
352 | 355 | " criterion = nn.CrossEntropyLoss()\n", |
353 | 356 | " optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)\n", |
354 | 357 | "\n", |
|
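Note: the loop body of `train` falls outside this hunk. Below is a sketch of the conventional SGD loop that would follow the `criterion`/`optimizer` setup shown above; the epoch count and the absence of logging are assumptions, not the notebook's actual code.

import torch.nn as nn
import torch.optim as optim

def train(net, trainloader, epochs=2):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    for epoch in range(epochs):
        for inputs, labels in trainloader:
            optimizer.zero_grad()                  # clear gradients from the previous step
            loss = criterion(net(inputs), labels)  # forward pass + loss
            loss.backward()                        # backpropagate
            optimizer.step()                       # SGD update with momentum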
390 | 393 | "execution_count": 7, |
391 | 394 | "id": "0b9764a8-674c-42ae-ad4b-f2dea027bdbf", |
392 | 395 | "metadata": { |
| 396 | + "lines_to_next_cell": 2, |
393 | 397 | "tags": [] |
394 | 398 | }, |
395 | 399 | "outputs": [ |
|
542 | 546 | "\n", |
543 | 547 | "import torch\n", |
544 | 548 | "\n", |
545 | | - "\n", |
546 | 549 | "def test(net, testloader):\n", |
547 | 550 | " correct = 0\n", |
548 | 551 | " total = 0\n", |
549 | | - " \n", |
| 552 | + "\n", |
550 | 553 | " with torch.no_grad():\n", |
551 | 554 | " for data in testloader:\n", |
552 | 555 | " images, labels = data\n", |
|
555 | 558 | " total += labels.size(0)\n", |
556 | 559 | " correct += (predicted == labels).sum().item()\n", |
557 | 560 | "\n", |
558 | | - " print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')\n", |
559 | | - " " |
| 561 | + " print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')\n" |
560 | 562 | ] |
561 | 563 | }, |
562 | 564 | { |
563 | 565 | "cell_type": "code", |
564 | 566 | "execution_count": 3, |
565 | 567 | "id": "fb49aef2-9fb5-4e74-83d2-9da935e07648", |
566 | | - "metadata": {}, |
| 568 | + "metadata": { |
| 569 | + "lines_to_next_cell": 2 |
| 570 | + }, |
567 | 571 | "outputs": [ |
568 | 572 | { |
569 | 573 | "name": "stdout", |
|
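Note: `100 * correct // total` in `test` uses floor division, so the printed accuracy is truncated to a whole percent. Putting the pieces together (a hypothetical driver, not a cell from the notebook):

net = Net()
train(net, get_dataloader(True, transform))
test(net, get_dataloader(False, transform))
# prints "Accuracy of the network on the 10000 test images: <N> %"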
678 | 682 | "import torch\n", |
679 | 683 | "from PIL import Image\n", |
680 | 684 | "\n", |
681 | | - "\n", |
682 | 685 | "def inference(net, transforms, filenames):\n", |
683 | 686 | " for fn in filenames:\n", |
684 | 687 | " with Image.open(fn) as im:\n", |
685 | 688 | " tim=transforms(im)\n", |
686 | 689 | " outputs=net(tim[None])\n", |
687 | | - " _, predictions = torch.max(outputs, 1)\n", |
| 690 | + " _, predictions=torch.max(outputs, 1)\n", |
688 | 691 | " print(fn, predictions[0].item())" |
689 | 692 | ] |
690 | 693 | }, |
|
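Note: `inference` prints the raw predicted index (`tim[None]` adds the batch dimension the network expects). To report class names instead, the index can be mapped through CIFAR-10's label order; the variant below is a hypothetical helper, not part of this commit.

import torch
from PIL import Image

classes = ('airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')  # CIFAR-10 label order

def inference_named(net, transforms, filenames):
    for fn in filenames:
        with Image.open(fn) as im:
            tim = transforms(im)
        with torch.no_grad():                  # inference only, no gradients needed
            outputs = net(tim[None])           # [None] adds the batch dimension
        _, predictions = torch.max(outputs, 1)
        print(fn, classes[predictions[0].item()])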