Spaces:
Runtime error
Runtime error
valentin urena commited on
Update chess_board.py
Browse files- chess_board.py +16 -20
chess_board.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
|
| 8 |
import chess
|
| 9 |
import chess.svg
|
|
@@ -17,10 +17,10 @@ class Game:
|
|
| 17 |
self.counter = 0
|
| 18 |
self.arrow= None
|
| 19 |
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
|
| 25 |
def compile_model(self):
|
| 26 |
self.model.compile(sampler=self.sampler)
|
|
@@ -37,9 +37,9 @@ class Game:
|
|
| 37 |
instruction=f"Predict the next chess move in the sequence {str(self.sequence)}",
|
| 38 |
response="",)
|
| 39 |
|
| 40 |
-
|
| 41 |
|
| 42 |
-
|
| 43 |
|
| 44 |
if self.make_move(gemma_move):
|
| 45 |
print(f'Gemma plays {self.sequence[-1]}! (Current Sequence: {self.sequence} {len(self.sequence)})')
|
|
@@ -54,8 +54,7 @@ class Game:
|
|
| 54 |
return None
|
| 55 |
|
| 56 |
def gemma_moves(self):
|
| 57 |
-
|
| 58 |
-
# time.sleep(3)
|
| 59 |
if self.opening_moves and len(self.sequence)<len(self.opening_moves):
|
| 60 |
return self.call_gemma(self.opening_moves[len(self.sequence)])
|
| 61 |
else:
|
|
@@ -64,23 +63,20 @@ class Game:
|
|
| 64 |
def player_moves(self, move):
|
| 65 |
return self.make_move(move)
|
| 66 |
|
| 67 |
-
# Function to display the board
|
| 68 |
def display_board(self):
|
| 69 |
-
|
| 70 |
-
# display(SVG(chess.svg.board(board=self.board)))
|
| 71 |
if self.arrow:
|
| 72 |
board_svg = chess.svg.board(board=self.board, arrows=[self.arrow])
|
| 73 |
else:
|
| 74 |
board_svg = chess.svg.board(board=self.board)
|
| 75 |
-
# return svg2png(bytestring=board_svg)
|
| 76 |
return board_svg
|
| 77 |
|
| 78 |
-
|
| 79 |
def make_move(self, move):
|
|
|
|
| 80 |
try:
|
| 81 |
update = self.board.parse_san(move)
|
| 82 |
self.board.push(update)
|
| 83 |
-
# self.display_board()
|
| 84 |
self.sequence.append(move)
|
| 85 |
self.arrow = chess.svg.Arrow(update.from_square, update.to_square, color="#0000cccc")
|
| 86 |
return True
|
|
@@ -93,10 +89,10 @@ class Game:
|
|
| 93 |
self.sequence = []
|
| 94 |
self.counter = 0
|
| 95 |
self.arrow = None
|
| 96 |
-
# self.board.reset
|
| 97 |
return self.display_board()
|
| 98 |
|
| 99 |
def generate_moves(self, move):
|
|
|
|
| 100 |
valid_move = self.player_moves(move)
|
| 101 |
if valid_move:
|
| 102 |
yield self.display_board(), f"You played: {move}"
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
os.environ["KERAS_BACKEND"] = "torch" # "jax", "torch" or "tensorflow"
|
| 3 |
|
| 4 |
+
import keras_nlp
|
| 5 |
+
import keras
|
| 6 |
+
import torch
|
| 7 |
|
| 8 |
import chess
|
| 9 |
import chess.svg
|
|
|
|
| 17 |
self.counter = 0
|
| 18 |
self.arrow= None
|
| 19 |
|
| 20 |
+
self.model_id = 'kaggle://valentinbaltazar/gemma-chess/keras/gemma_2b_en_chess'
|
| 21 |
+
self.sampler = keras_nlp.samplers.TopKSampler(k=50, temperature=0.7)
|
| 22 |
+
self.model = keras_nlp.models.GemmaCausalLM.from_preset(self.model_id)
|
| 23 |
+
self.compile_model()
|
| 24 |
|
| 25 |
def compile_model(self):
|
| 26 |
self.model.compile(sampler=self.sampler)
|
|
|
|
| 37 |
instruction=f"Predict the next chess move in the sequence {str(self.sequence)}",
|
| 38 |
response="",)
|
| 39 |
|
| 40 |
+
output = self.model.generate(prompt, max_length=256)
|
| 41 |
|
| 42 |
+
gemma_move = output.split(' ')[-1].strip("'")
|
| 43 |
|
| 44 |
if self.make_move(gemma_move):
|
| 45 |
print(f'Gemma plays {self.sequence[-1]}! (Current Sequence: {self.sequence} {len(self.sequence)})')
|
|
|
|
| 54 |
return None
|
| 55 |
|
| 56 |
def gemma_moves(self):
|
| 57 |
+
"""Calls Gemma to make a move, either self generated or from opening sequence"""
|
|
|
|
| 58 |
if self.opening_moves and len(self.sequence)<len(self.opening_moves):
|
| 59 |
return self.call_gemma(self.opening_moves[len(self.sequence)])
|
| 60 |
else:
|
|
|
|
| 63 |
def player_moves(self, move):
|
| 64 |
return self.make_move(move)
|
| 65 |
|
|
|
|
| 66 |
def display_board(self):
|
| 67 |
+
"""Return SVG image of board state"""
|
|
|
|
| 68 |
if self.arrow:
|
| 69 |
board_svg = chess.svg.board(board=self.board, arrows=[self.arrow])
|
| 70 |
else:
|
| 71 |
board_svg = chess.svg.board(board=self.board)
|
|
|
|
| 72 |
return board_svg
|
| 73 |
|
| 74 |
+
|
| 75 |
def make_move(self, move):
|
| 76 |
+
"""Checks to see if move is valid, if so pushes move to board state"""
|
| 77 |
try:
|
| 78 |
update = self.board.parse_san(move)
|
| 79 |
self.board.push(update)
|
|
|
|
| 80 |
self.sequence.append(move)
|
| 81 |
self.arrow = chess.svg.Arrow(update.from_square, update.to_square, color="#0000cccc")
|
| 82 |
return True
|
|
|
|
| 89 |
self.sequence = []
|
| 90 |
self.counter = 0
|
| 91 |
self.arrow = None
|
|
|
|
| 92 |
return self.display_board()
|
| 93 |
|
| 94 |
def generate_moves(self, move):
|
| 95 |
+
"""Generator function for one full turn of chess moves"""
|
| 96 |
valid_move = self.player_moves(move)
|
| 97 |
if valid_move:
|
| 98 |
yield self.display_board(), f"You played: {move}"
|