Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified TicTacToe/policy_p1
Binary file not shown.
Binary file modified TicTacToe/policy_p2
Binary file not shown.
255 changes: 64 additions & 191 deletions TicTacToe/tic-tac-toe.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 54,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -23,12 +23,13 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 55,
"metadata": {},
"outputs": [],
"source": [
"BOARD_ROWS = 3\n",
"BOARD_COLS = 3"
"BOARD_COLS = 3\n",
"train = False"
]
},
{
Expand All @@ -44,7 +45,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 56,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -83,11 +84,10 @@
" # diagonal\n",
" diag_sum1 = sum([self.board[i, i] for i in range(BOARD_COLS)])\n",
" diag_sum2 = sum([self.board[i, BOARD_COLS-i-1] for i in range(BOARD_COLS)])\n",
" diag_sum = max(diag_sum1, diag_sum2)\n",
" if diag_sum == 3:\n",
" if diag_sum1 == 3 or diag_sum2 == 3:\n",
" self.isEnd = True\n",
" return 1\n",
" if diag_sum == -3:\n",
" if diag_sum1 == -3 or diag_sum2 == -3:\n",
" self.isEnd = True\n",
" return -1\n",
" \n",
Expand Down Expand Up @@ -222,15 +222,18 @@
" if self.board[i, j] == -1:\n",
" token = 'o'\n",
" if self.board[i, j] == 0:\n",
" token = ' '\n",
" token = str(j+1 + i*3)\n",
"\n",
"\n",
"\n",
" out += token + ' | '\n",
" print(out)\n",
" print('-------------') "
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 57,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -294,21 +297,33 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 58,
"metadata": {},
"outputs": [],
"source": [
"class HumanPlayer:\n",
" def __init__(self, name):\n",
" self.name = name \n",
" \n",
" def cell_to_row_col(self,cell_number):\n",
" if not 1 <= cell_number <= 9:\n",
" raise ValueError(\"Cell number must be between 1 and 9\")\n",
"\n",
" row = (cell_number - 1) // 3\n",
" col = (cell_number - 1) % 3\n",
" return row, col\n",
"\n",
" \n",
" def chooseAction(self, positions):\n",
" while True:\n",
" row = int(input(\"Input your action row:\"))\n",
" col = int(input(\"Input your action col:\"))\n",
" cell = int(input(\"Input your action:\"))\n",
" row,col = self.cell_to_row_col(cell)\n",
" action = (row, col)\n",
" if action in positions:\n",
" return action\n",
" \n",
"\n",
"\n",
" \n",
" # append a hash state\n",
" def addState(self, state):\n",
Expand All @@ -331,91 +346,28 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 59,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"training...\n",
"Rounds 0\n",
"Rounds 1000\n",
"Rounds 2000\n",
"Rounds 3000\n",
"Rounds 4000\n",
"Rounds 5000\n",
"Rounds 6000\n",
"Rounds 7000\n",
"Rounds 8000\n",
"Rounds 9000\n",
"Rounds 10000\n",
"Rounds 11000\n",
"Rounds 12000\n",
"Rounds 13000\n",
"Rounds 14000\n",
"Rounds 15000\n",
"Rounds 16000\n",
"Rounds 17000\n",
"Rounds 18000\n",
"Rounds 19000\n",
"Rounds 20000\n",
"Rounds 21000\n",
"Rounds 22000\n",
"Rounds 23000\n",
"Rounds 24000\n",
"Rounds 25000\n",
"Rounds 26000\n",
"Rounds 27000\n",
"Rounds 28000\n",
"Rounds 29000\n",
"Rounds 30000\n",
"Rounds 31000\n",
"Rounds 32000\n",
"Rounds 33000\n",
"Rounds 34000\n",
"Rounds 35000\n",
"Rounds 36000\n",
"Rounds 37000\n",
"Rounds 38000\n",
"Rounds 39000\n",
"Rounds 40000\n",
"Rounds 41000\n",
"Rounds 42000\n",
"Rounds 43000\n",
"Rounds 44000\n",
"Rounds 45000\n",
"Rounds 46000\n",
"Rounds 47000\n",
"Rounds 48000\n",
"Rounds 49000\n"
]
}
],
"outputs": [],
"source": [
"p1 = Player(\"p1\")\n",
"p2 = Player(\"p2\")\n",
"\n",
"st = State(p1, p2)\n",
"print(\"training...\")\n",
"st.play(50000)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"p1.savePolicy()\n",
"p2.savePolicy()"
"\n",
"# If you want to re-train the agent change the variable at the beginning of the code to True\n",
"# The agent takes approximately 3 minutes to train\n",
"if train:\n",
" print(\"training...\")\n",
" st.play(50000)\n",
" p1.savePolicy()\n",
" p2.savePolicy()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -431,149 +383,70 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 63,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-------------\n",
"| | | | \n",
"-------------\n",
"| | x | | \n",
"| x | 2 | 3 | \n",
"-------------\n",
"| | | | \n",
"| 4 | 5 | 6 | \n",
"-------------\n",
"Input your action row:2\n",
"Input your action col:2\n",
"| 7 | 8 | 9 | \n",
"-------------\n",
"| | | | \n",
"-------------\n",
"| | x | | \n",
"| x | 2 | 3 | \n",
"-------------\n",
"| | | o | \n",
"| 4 | o | 6 | \n",
"-------------\n",
"| 7 | 8 | 9 | \n",
"-------------\n",
"| | | | \n",
"-------------\n",
"| | x | | \n",
"| x | 2 | 3 | \n",
"-------------\n",
"| | x | o | \n",
"| 4 | o | 6 | \n",
"-------------\n",
"Input your action row:0\n",
"Input your action col:1\n",
"| 7 | 8 | x | \n",
"-------------\n",
"| | o | | \n",
"-------------\n",
"| | x | | \n",
"| x | 2 | 3 | \n",
"-------------\n",
"| | x | o | \n",
"| 4 | o | 6 | \n",
"-------------\n",
"| 7 | o | x | \n",
"-------------\n",
"| | o | x | \n",
"-------------\n",
"| | x | | \n",
"| x | x | 3 | \n",
"-------------\n",
"| | x | o | \n",
"| 4 | o | 6 | \n",
"-------------\n",
"Input your action row:1\n",
"Input your action col:1\n",
"Input your action row:1\n",
"Input your action col:0\n",
"| 7 | o | x | \n",
"-------------\n",
"| | o | x | \n",
"-------------\n",
"| o | x | | \n",
"-------------\n",
"| | x | o | \n",
"-------------\n",
"-------------\n",
"| | o | x | \n",
"-------------\n",
"| o | x | | \n",
"-------------\n",
"| x | x | o | \n",
"-------------\n",
"computer wins!\n"
]
}
],
"source": [
"p1 = Player(\"computer\", exp_rate=0)\n",
"p1.loadPolicy(\"policy_p1\")\n",
"\n",
"p2 = HumanPlayer(\"human\")\n",
"\n",
"st = State(p1, p2)\n",
"st.play2()"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-------------\n",
"| | | | \n",
"| 4 | o | 6 | \n",
"-------------\n",
"| | | x | \n",
"| 7 | o | x | \n",
"-------------\n",
"| | | | \n",
"-------------\n",
"Input your action row:2\n",
"Input your action col:2\n",
"-------------\n",
"| | | | \n",
"-------------\n",
"| | | x | \n",
"-------------\n",
"| | | o | \n",
"-------------\n",
"-------------\n",
"| | | | \n",
"-------------\n",
"| | x | x | \n",
"-------------\n",
"| | | o | \n",
"-------------\n",
"Input your action row:1\n",
"Input your action col:0\n",
"-------------\n",
"| | | | \n",
"-------------\n",
"| o | x | x | \n",
"-------------\n",
"| | | o | \n",
"-------------\n",
"-------------\n",
"| | | | \n",
"-------------\n",
"| o | x | x | \n",
"-------------\n",
"| x | | o | \n",
"-------------\n",
"Input your action row:0\n",
"Input your action col:0\n",
"-------------\n",
"| o | | | \n",
"| x | x | o | \n",
"-------------\n",
"| o | x | x | \n",
"| x | o | 6 | \n",
"-------------\n",
"| x | | o | \n",
"| 7 | o | x | \n",
"-------------\n",
"-------------\n",
"| o | | x | \n",
"| x | x | o | \n",
"-------------\n",
"| o | x | x | \n",
"| x | o | 6 | \n",
"-------------\n",
"| x | | o | \n",
"| o | o | x | \n",
"-------------\n",
"computer wins!\n"
"human wins!\n"
]
}
],
Expand Down Expand Up @@ -604,7 +477,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
"version": "3.13.5"
}
},
"nbformat": 4,
Expand Down
Loading