From 4208b85a3d1693bb787d471e4c7a08ad1e056e24 Mon Sep 17 00:00:00 2001 From: tanayag Date: Tue, 23 Oct 2018 14:21:17 +0530 Subject: [PATCH] add --- test_one.ipynb | 720 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 720 insertions(+) create mode 100644 test_one.ipynb diff --git a/test_one.ipynb b/test_one.ipynb new file mode 100644 index 0000000..d221ba4 --- /dev/null +++ b/test_one.ipynb @@ -0,0 +1,720 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "from torch.autograd import Variable" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# generating dataset for discriminator\n", + "def get_dataset():\n", + " return torch.Tensor(np.random.normal(4,1.25,(1,100)))" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# generating noise for generator\n", + "def make_noise():\n", + " return torch.Tensor(np.random.uniform(0,0,(1,50)))" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "class generator(nn.Module):\n", + " \n", + " def __init__(self,inp,out):\n", + " super(generator,self).__init__()\n", + " self.net = nn.Sequential(nn.Linear(inp,300),\n", + " nn.ReLU(inplace=True),\n", + " nn.Linear(300,300),\n", + " nn.ReLU(inplace=True),\n", + " nn.Linear(300,out)\n", + " )\n", + " \n", + " def forward(self,x):\n", + " x = self.net(x)\n", + " return x\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "class discriminator(nn.Module):\n", + " \n", + " def __init__(self,inp,out):\n", + " super(discriminator,self).__init__()\n", + " self.net = nn.Sequential(nn.Linear(inp,300),\n", + " nn.ReLU(inplace=True),\n", + " nn.Linear(300,300),\n", + " nn.ReLU(inplace=True),\n", + " nn.Linear(300,out),\n", + " nn.Sigmoid()\n", + " )\n", + " \n", + " def forward(self,x):\n", + " x = self.net(x)\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "def stats(array):\n", + " array = array.detach().numpy()\n", + " return [np.mean(array),np.std(array)]" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "gen = generator(50,100)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "generator(\n", + " (net): Sequential(\n", + " (0): Linear(in_features=50, out_features=300, bias=True)\n", + " (1): ReLU(inplace)\n", + " (2): Linear(in_features=300, out_features=300, bias=True)\n", + " (3): ReLU(inplace)\n", + " (4): Linear(in_features=300, out_features=100, bias=True)\n", + " )\n", + ")" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "gen" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "x = make_noise()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "x1 = gen(x)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[0.0018841184, 0.036860824]" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "stats(x1)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "discrim = discriminator(100,1)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "discriminator(\n", + " (net): Sequential(\n", + " (0): Linear(in_features=100, out_features=300, bias=True)\n", + " (1): ReLU(inplace)\n", + " (2): Linear(in_features=300, out_features=300, bias=True)\n", + " (3): ReLU(inplace)\n", + " (4): Linear(in_features=300, out_features=1, bias=True)\n", + " (5): Sigmoid()\n", + " )\n", + ")" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "discrim" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "x2 = discrim(x1)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 0.5205]])" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x2" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "epochs = 500" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "d_step = 10\n", + "g_step = 8\n", + "criteriond1 = nn.BCELoss()\n", + "optimizerd1 = optim.SGD(discrim.parameters(), lr=0.001, momentum=0.9)\n", + "\n", + "criteriond2 = nn.BCELoss()\n", + "optimizerd2 = optim.SGD(gen.parameters(), lr=0.001, momentum=0.9)\n", + "\n", + "printing_steps = 20" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "('Epoch:', 0)\n", + "[0.0018841184, 0.036860824]\n", + "[0.001885428, 0.036860917]\n", + "[0.0018879148, 0.03686111]\n", + "[0.0018914625, 0.036861386]\n", + "[0.0018959647, 0.036861762]\n", + "[0.0019013253, 0.036862247]\n", + "[0.0019074584, 0.03686284]\n", + "[0.0019142872, 0.036863554]\n", + "\n", + "\n", + "\n", + "('Epoch:', 20)\n", + "[0.007975619, 0.038056422]\n", + "[0.008030179, 0.03808928]\n", + "[0.008086704, 0.03812305]\n", + "[0.008145271, 0.03815762]\n", + "[0.008205742, 0.038193222]\n", + "[0.008268582, 0.038229816]\n", + "[0.008333553, 0.038267855]\n", + "[0.008401263, 0.038307115]\n", + "\n", + "\n", + "\n", + "('Epoch:', 40)\n", + "[0.025739962, 0.07039434]\n", + "[0.025967699, 0.07082409]\n", + "[0.026195178, 0.071255475]\n", + "[0.026425453, 0.071689904]\n", + "[0.026660535, 0.072128154]\n", + "[0.026900362, 0.07257016]\n", + "[0.027144287, 0.07301512]\n", + "[0.027392678, 0.07346259]\n", + "\n", + "\n", + "\n", + "('Epoch:', 60)\n", + "[0.08450457, 0.14448352]\n", + "[0.08509155, 0.14511523]\n", + "[0.08568327, 0.14574906]\n", + "[0.08627997, 0.14638585]\n", + "[0.0868816, 0.14702617]\n", + "[0.08748844, 0.1476706]\n", + "[0.08809981, 0.1483197]\n", + "[0.08871749, 0.1489742]\n", + "\n", + "\n", + "\n", + "('Epoch:', 80)\n", + "[0.28840774, 0.27236685]\n", + "[0.29072142, 0.2731846]\n", + "[0.29304278, 0.27398548]\n", + "[0.2953758, 0.27477252]\n", + "[0.2977213, 0.27554947]\n", + "[0.30008143, 0.27632013]\n", + "[0.30245793, 0.27708793]\n", + "[0.30485335, 0.27785522]\n", + "\n", + "\n", + "\n", + "('Epoch:', 100)\n", + "[1.1765242, 0.47627425]\n", + "[1.1888257, 0.48015204]\n", + "[1.2012541, 0.48397332]\n", + "[1.2138283, 0.4878198]\n", + "[1.2265671, 0.49175856]\n", + "[1.2395148, 0.49582666]\n", + "[1.2526832, 0.5000789]\n", + "[1.2660835, 0.50456303]\n", + "\n", + "\n", + "\n", + "('Epoch:', 120)\n", + "[3.7521708, 1.1674445]\n", + "[3.7650826, 1.1632067]\n", + "[3.7773826, 1.1575049]\n", + "[3.7894604, 1.1513976]\n", + "[3.8016295, 1.1458958]\n", + "[3.8140812, 1.1418389]\n", + "[3.8267963, 1.1399107]\n", + "[3.839567, 1.14043]\n", + "\n", + "\n", + "\n", + "('Epoch:', 140)\n", + "[5.4304433, 0.95063484]\n", + "[5.4390655, 0.9454818]\n", + "[5.4475994, 0.9381844]\n", + "[5.45611, 0.9292164]\n", + "[5.4646544, 0.9190324]\n", + "[5.473289, 0.90804404]\n", + "[5.4820514, 0.8966615]\n", + "[5.490971, 0.88520867]\n", + "\n", + "\n", + "\n", + "('Epoch:', 160)\n", + "[6.554171, 1.0230845]\n", + "[6.5443854, 1.0079885]\n", + "[6.5336504, 0.9934966]\n", + "[6.5223026, 0.98004055]\n", + "[6.510661, 0.968119]\n", + "[6.4991875, 0.95708686]\n", + "[6.488184, 0.9479753]\n", + "[6.4779205, 0.94164]\n", + "\n", + "\n", + "\n", + "('Epoch:', 180)\n", + "[6.6743093, 1.1352086]\n", + "[6.674415, 1.1340271]\n", + "[6.6746874, 1.1323054]\n", + "[6.6751537, 1.1302996]\n", + "[6.675848, 1.1281751]\n", + "[6.6767945, 1.1260948]\n", + "[6.678016, 1.1241832]\n", + "[6.67953, 1.1225513]\n", + "\n", + "\n", + "\n", + "('Epoch:', 200)\n", + "[6.7470884, 1.0527247]\n", + "[6.7447414, 1.0537938]\n", + "[6.742381, 1.0543174]\n", + "[6.740056, 1.0544192]\n", + "[6.73781, 1.0542215]\n", + "[6.735685, 1.0538297]\n", + "[6.733716, 1.0533464]\n", + "[6.731936, 1.0528711]\n", + "\n", + "\n", + "\n", + "('Epoch:', 220)\n", + "[5.3740454, 1.1091442]\n", + "[5.3554115, 1.1135103]\n", + "[5.336738, 1.1160842]\n", + "[5.318118, 1.1171281]\n", + "[5.299648, 1.116884]\n", + "[5.281429, 1.1155953]\n", + "[5.2635565, 1.1134976]\n", + "[5.246128, 1.1108178]\n", + "\n", + "\n", + "\n", + "('Epoch:', 240)\n", + "[3.563117, 0.7772749]\n", + "[3.5572462, 0.7718341]\n", + "[3.5509315, 0.76590574]\n", + "[3.5443897, 0.7597089]\n", + "[3.5378225, 0.75346625]\n", + "[3.5314343, 0.74733096]\n", + "[3.5253472, 0.74146765]\n", + "[3.5196652, 0.73601395]\n", + "\n", + "\n", + "\n", + "('Epoch:', 260)\n", + "[3.762761, 0.44140473]\n", + "[3.766048, 0.44176507]\n", + "[3.769457, 0.44208202]\n", + "[3.772982, 0.442378]\n", + "[3.7766001, 0.44267818]\n", + "[3.7803097, 0.44300264]\n", + "[3.7840986, 0.44338825]\n", + "[3.7879653, 0.44384944]\n", + "\n", + "\n", + "\n", + "('Epoch:', 280)\n", + "[4.547168, 0.5041005]\n", + "[4.5512414, 0.50436765]\n", + "[4.5553017, 0.50460684]\n", + "[4.5593777, 0.5048857]\n", + "[4.5634933, 0.50526166]\n", + "[4.5676703, 0.505786]\n", + "[4.571926, 0.5064988]\n", + "[4.5762706, 0.5074317]\n", + "\n", + "\n", + "\n", + "('Epoch:', 300)\n", + "[4.714465, 0.60834765]\n", + "[4.7135544, 0.6082357]\n", + "[4.712129, 0.607926]\n", + "[4.71026, 0.6074627]\n", + "[4.7080126, 0.60689396]\n", + "[4.705456, 0.6062651]\n", + "[4.702648, 0.6056212]\n", + "[4.699642, 0.6050036]\n", + "\n", + "\n", + "\n", + "('Epoch:', 320)\n", + "[4.1752644, 0.49929926]\n", + "[4.1692877, 0.5012012]\n", + "[4.1635256, 0.5029179]\n", + "[4.157988, 0.5044922]\n", + "[4.1526837, 0.5059638]\n", + "[4.1476197, 0.50736845]\n", + "[4.1428027, 0.5087379]\n", + "[4.1382384, 0.51009995]\n", + "\n", + "\n", + "\n", + "('Epoch:', 340)\n", + "[3.9004126, 0.49423027]\n", + "[3.899606, 0.49115932]\n", + "[3.8988214, 0.48803943]\n", + "[3.8981094, 0.48495686]\n", + "[3.8975136, 0.4819867]\n", + "[3.897069, 0.47919214]\n", + "[3.8968034, 0.47662398]\n", + "[3.8967361, 0.47432172]\n", + "\n", + "\n", + "\n", + "('Epoch:', 360)\n", + "[3.845068, 0.45620793]\n", + "[3.8458064, 0.45662042]\n", + "[3.846777, 0.45696673]\n", + "[3.847946, 0.45725322]\n", + "[3.8492954, 0.4574869]\n", + "[3.850814, 0.45767534]\n", + "[3.8524847, 0.4578276]\n", + "[3.8542979, 0.45795205]\n", + "\n", + "\n", + "\n", + "('Epoch:', 380)\n", + "[4.4157295, 0.50118345]\n", + "[4.4171925, 0.5007066]\n", + "[4.418768, 0.5001312]\n", + "[4.420454, 0.49948642]\n", + "[4.4222555, 0.49879527]\n", + "[4.4241753, 0.49808162]\n", + "[4.4262133, 0.49736917]\n", + "[4.4283686, 0.49667883]\n", + "\n", + "\n", + "\n", + "('Epoch:', 400)\n", + "[4.5400233, 0.5564293]\n", + "[4.540456, 0.5564975]\n", + "[4.5409265, 0.5565482]\n", + "[4.5414352, 0.5565863]\n", + "[4.541979, 0.55661666]\n", + "[4.542558, 0.55664325]\n", + "[4.5431724, 0.55666983]\n", + "[4.54382, 0.55670017]\n", + "\n", + "\n", + "\n", + "('Epoch:', 420)\n", + "[4.692982, 0.5231943]\n", + "[4.692985, 0.5231961]\n", + "[4.6929874, 0.5231977]\n", + "[4.6929893, 0.5231991]\n", + "[4.6929913, 0.5232004]\n", + "[4.6929936, 0.52320147]\n", + "[4.692995, 0.5232025]\n", + "[4.6929965, 0.5232035]\n", + "\n", + "\n", + "\n", + "('Epoch:', 440)\n", + "[5.0075383, 0.43208402]\n", + "[5.007661, 0.43213037]\n", + "[5.007836, 0.4321994]\n", + "[5.008058, 0.43229046]\n", + "[5.0083237, 0.43240368]\n", + "[5.0086274, 0.43253857]\n", + "[5.008966, 0.43269572]\n", + "[5.0093365, 0.43287548]\n", + "\n", + "\n", + "\n", + "('Epoch:', 460)\n", + "[5.029313, 0.44687128]\n", + "[5.0294237, 0.44697717]\n", + "[5.0295234, 0.4470743]\n", + "[5.0296135, 0.4471632]\n", + "[5.0296936, 0.4472443]\n", + "[5.029766, 0.4473183]\n", + "[5.0298314, 0.44738564]\n", + "[5.02989, 0.44744688]\n", + "\n", + "\n", + "\n", + "('Epoch:', 480)\n", + "[5.0598216, 0.46177983]\n", + "[5.0598273, 0.461782]\n", + "[5.059832, 0.46178398]\n", + "[5.0598364, 0.46178573]\n", + "[5.059841, 0.4617873]\n", + "[5.0598445, 0.4617887]\n", + "[5.059849, 0.46178997]\n", + "[5.059851, 0.4617911]\n", + "\n", + "\n", + "\n" + ] + } + ], + "source": [ + "for epoch in range(epochs):\n", + " \n", + " if epoch%printing_steps==0:\n", + " print(\"Epoch:\", epoch)\n", + " \n", + " # training discriminator\n", + " for d_i in range(d_step):\n", + " discrim.zero_grad()\n", + " \n", + " #real\n", + " data_d_real = Variable(get_dataset())\n", + " data_d_real_pred = discrim(data_d_real)\n", + " data_d_real_loss = criteriond1(data_d_real_pred,Variable(torch.ones(1,1)))\n", + " data_d_real_loss.backward()\n", + " \n", + " \n", + " #fake\n", + " data_d_noise = Variable(make_noise())\n", + " data_d_gen_out = gen(data_d_noise).detach()\n", + " data_fake_dicrim_out = discrim(data_d_gen_out)\n", + " data_fake_d_loss = criteriond1(data_fake_dicrim_out,Variable(torch.zeros(1,1)))\n", + " data_fake_d_loss.backward()\n", + " \n", + " optimizerd1.step()\n", + " \n", + " for g_i in range(g_step):\n", + " \n", + " gen.zero_grad()\n", + " \n", + " data_noise_gen = Variable(make_noise())\n", + " data_g_gen_out = gen(data_noise_gen)\n", + " data_g_dis_out = discrim(data_g_gen_out)\n", + " data_g_loss = criteriond2(data_g_dis_out,Variable(torch.ones(1,1)))\n", + " data_g_loss.backward()\n", + " \n", + " optimizerd2.step()\n", + " \n", + " if epoch%printing_steps==0:\n", + " print(stats(data_g_gen_out))\n", + " \n", + " if epoch%printing_steps==0:\n", + " print(\"\\n\\n\")\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "generator(\n", + " (net): Sequential(\n", + " (0): Linear(in_features=50, out_features=300, bias=True)\n", + " (1): ReLU(inplace)\n", + " (2): Linear(in_features=300, out_features=300, bias=True)\n", + " (3): ReLU(inplace)\n", + " (4): Linear(in_features=300, out_features=100, bias=True)\n", + " )\n", + ")" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "gen" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "# torch.save(gen.state_dict(),\"d_step_5000_g_step_2000\")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "y = gen(make_noise())" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[5.0801945, 0.45474356]" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "stats(y)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 2", + "language": "python", + "name": "python2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.15" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}