@@ -290,6 +290,17 @@ async def mcts_search(self, websocket=None) -> Optional[LATSNode]:
290290 print_entire_tree (self .root_node )
291291
292292
293+ # optional: prior value
294+ if self .config .set_prior_value :
295+ await self .websocket_step_start (step = 2 , step_name = "node_children_evaluation" , websocket = websocket )
296+ await self .node_children_evaluation (selected_node )
297+ tree_data = self ._get_tree_data ()
298+ if websocket :
299+ await self .websocket_tree_update (type = "tree_update_node_children_evaluation" , websocket = websocket , tree_data = tree_data )
300+ else :
301+ print ("after evaluation" )
302+ print_entire_tree (self .root_node )
303+
293304 # Step 3: simulation using the current node, (generate a path using the current node, and score the path)
294305 # TODO: implement simulation using openai
295306 print (f"{ GREEN } Step 3: Simulation{ RESET } " )
@@ -337,14 +348,15 @@ async def mcts_search(self, websocket=None) -> Optional[LATSNode]:
337348 print (f"{ GREEN } Step 5: Backpropagation{ RESET } " )
338349 await self .websocket_step_start (step = 5 , step_name = "backpropagation" , websocket = websocket )
339350 for node in path :
340- old_value = node .value
341- node .visits += 1
342- node .value += (score - node .value ) / node .visits
343- # consiste with lats backpropagation
344- #node.value = (node.value * (node.visits - 1) + score) / node.visits
345- print (f"Node { node .action } :" )
346- print (f" Visits: { node .visits } " )
347- print (f" Value: { old_value :.3f} -> { node .value :.3f} " )
351+ if node != self .root_node :
352+ old_value = node .value
353+ node .visits += 1
354+ node .value += (score - node .value ) / node .visits
355+ # consiste with lats backpropagation
356+ #node.value = (node.value * (node.visits - 1) + score) / node.visits
357+ print (f"Node { node .action } :" )
358+ print (f" Visits: { node .visits } " )
359+ print (f" Value: { old_value :.3f} -> { node .value :.3f} " )
348360 # add websocket information, just use websocket here
349361 # if websocket:
350362 # await websocket.send_json({
0 commit comments