import java.util.*;
import java.io.*;
import index;
import state;
//******************************************************
// CLASS: Policy
//******************************************************
// This class handles the SARSA learning and updates the
// game policy.
//******************************************************

class policy {
	float state_action_value [][][];
	float alpha;
	float gama;
	int epsilon;
	index actions;
	index categories;
	index states;
	
	public policy(int random_policy, float horizon,float learning_ability ) 
		{
			alpha   = learning_ability;  
			epsilon = random_policy; 
			gama    = horizon;
			actions = new index("actions.txt");
			categories = new index("categories.txt");
			states = new index("states.txt");
			state_action_value = new float [states.max_index+1][categories.max_index+1][actions.max_index+1];
			for ( int i=0; i < states.max_index+1; i++ )
				for ( int j=0; j < categories.max_index+1; j++ )
					for ( int k=0; k < actions.max_index+1; k++ )
						state_action_value[i][j][k]=(float)0;
			//System.out.println("max state["+states.max_index+"], max category["+categories.max_index+"], max action["+actions.max_index+"].");
		}// of CTOR
		
		
	public void iteartion(game yatzee)
		{
			//System.out.println("Policy: start iteration");			
			state category_state = new state(yatzee.category[yatzee.player]);
			int category_index = categories.get(category_state.name());
			state new_state = new state();
			state new_action = new state();
			int new_state_index;
			if(yatzee.new_dice.numOfDie() == 0)
			// between rounds - update categories
				{
					//System.out.println("Policy: between rounds get new dice");
					new_state_index = states.get(new_state.name());
				}
			else 
				{
					new_state.merge(yatzee.dice_state);
					new_state.merge(yatzee.dice_action);
					new_state_index = states.get(new_state.name());
					Vector substates = yatzee.new_dice.substates();
					Enumeration enum = substates.elements();
					//System.out.println("Policy: turn["+yatzee.turn+"]");
					if ( yatzee.turn == 2 )
					//last turn - choose all	
						{
							//System.out.println("Policy: last turn choose all");
							new_action = yatzee.new_dice;
						}
					else if ( yatzee.r.dM(epsilon) == epsilon ) 
					// do random policy 
						{	
							//System.out.println("Policy: random policy");
							state random_action = new state();
							int state_to_go = yatzee.r.dM(substates.capacity());
							for ( int i=0; i<state_to_go; i++ )
								random_action = (state)enum.nextElement(); 						
							new_action = random_action;
						}
					else 
					// do greedy policy
						{ 
							//System.out.println("Policy: greedy policy");
							state current_action;
							float current_val, max_val = -1;
							while ( enum.hasMoreElements() ) 
								{
									current_action = (state)enum.nextElement();
									int current_action_index = actions.get(current_action.name());
									current_val = state_action_value[new_state_index][category_index][current_action_index];
									if ( current_val >= max_val ) 
										{
											if ( ( current_val > max_val) || ( yatzee.r.dM(substates.capacity()) < 3 ) )
												{
													//System.out.print("Policy: old max val["+max_val+"], new max val["+current_val+"] ==> current max action ");
													//current_action.print();
													max_val   = current_val;
													new_action    =  current_action ;
											  }
										}
								}
						}
				}
				// Update state action value 
				int action_index = actions.get(yatzee.dice_action.name());
				int state_index = states.get(yatzee.dice_state.name());
				int new_action_index = actions.get(new_action.name());	
				int imidiate_value = yatzee.immidiate_val();				
				state new_category_state = new state(yatzee.category[yatzee.player]);
				int new_category_index = categories.get(new_category_state.name());
				// print
				//System.out.print("policy:new_category_");
				//new_category_state.print();
				//System.out.println("--------------------------");
				//if (imidiate_value>0  ) {
				//	System.out.print("policy:state_");
				//	yatzee.dice_state.print();
				//	System.out.print("policy:action_");
				//	yatzee.dice_action.print();
				//}
				//System.out.print("policy:new_state_");
				//new_state.print();
				//System.out.print("policy:new_action_");
				//new_action.print();
				// compute
				float Qsa = state_action_value[state_index][category_index][action_index];
				//System.out.println("policy:Q["+state_index+"]["+category_index+"]["+action_index+"]="+Qsa);
				float QsaTag;
				if ( yatzee.one_game_end )
					{				
						//System.out.println("policy:Q' is Zero for end of game!");
						QsaTag = 0;
					}
				else	
					{	
						QsaTag = state_action_value[new_state_index][new_category_index][new_action_index];
						//System.out.println("policy:Q'["+new_state_index+"]["+new_category_index+"]["+new_action_index+"]="+QsaTag);
					}
				Qsa = Qsa + alpha*( imidiate_value + gama*(QsaTag-Qsa) );
				state_action_value[state_index][category_index][action_index] = Qsa;
				//System.out.println("policy:Q["+state_index+"]["+category_index+"]["+action_index+"]="+state_action_value[state_index][category_index][action_index]);
				//System.out.println("--------------------------");
				yatzee.dice_state = new_state;
				yatzee.dice_action = new_action;
		}// of iteration
		
		public void save() {
			DataOutputStream output;
			try  {
				output= new DataOutputStream( new FileOutputStream("action_state.save"));
				// save state-action table
				for ( int i=0; i < states.max_index+1; i++ )
					for ( int j=0; j < categories.max_index+1; j++ )
						for ( int k=0; k < actions.max_index+1; k++ )
							output.writeFloat(state_action_value[i][j][k]);
			}
			catch ( IOException e ) { 
				System.out.println( "Can't open save File- save blanck\n" );
			}
		}// of save

		public void load() {
			DataInputStream input;
			try  {
				input = new DataInputStream( new FileInputStream("action_state.save"));
				// save state-action table
				for ( int i=0; i < states.max_index+1; i++ )
					for ( int j=0; j < categories.max_index+1; j++ )
						for ( int k=0; k < actions.max_index+1; k++ )
							state_action_value[i][j][k] = input.readFloat();
			}
			catch ( IOException e ) { 
				System.out.println( "Can't open save File- load blanck\n" );
				for ( int i=0; i < states.max_index+1; i++ )
					for ( int j=0; j < categories.max_index+1; j++ )
						for ( int k=0; k < actions.max_index+1; k++ )
							state_action_value[i][j][k] = (float)0;

			}
		}// of save
}	