#! /usr/bin/env python

#  _________________________________________________________________________
#
#  Coopr: A COmmon Optimization Python Repository
#  Copyright (c) 2010 Sandia Corporation.
#  This software is distributed under the BSD License.
#  Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
#  the U.S. Government retains certain rights in this software.
#  For more information, see the FAST README.txt file.
#  _________________________________________________________________________

import os
import random
from coopr.pysp.scenariotree import *
from coopr.pysp.phinit import *
from coopr.pysp.ph import *
from coopr.pysp.ef import *

def run(args=None):
   """Estimate the optimality gap of a sampled candidate solution.

   The scenarios of the full (two-stage) scenario tree are randomly permuted
   and then split into:
     * a leading block of int(fraction_for_solve * scenariocount) scenarios,
       solved via PH (or via the extensive form with
       --solve-hatx-with-ef-only) to obtain the first-stage solution hat{x};
     * n_g disjoint groups of num_scenarios_per_sample scenarios each.
   Any remaining scenarios are not used.

   For each group k the group's extensive form is solved, which gives
   E_f_of_xstar and its gap.  Every scenario of the group is then re-solved
   with the first-stage variables fixed at hat{x}, and the probability-weighted
   costs give E_f_of_Hatx.  The quantity
   G^k = E_f_of_Hatx - (E_f_of_xstar - gap) is recorded.  The sample mean
   (Gbar) and the sample variance (VarG) of the G^k are printed.

   args -- command-line style argument list.  None is fine: the option
           parser then falls back to sys.argv.
   Exits with status 1 for problems with more than two stages, or when the
   options leave too few samples/scenarios to form the estimate.
   """
   import sys

   try:
      conf_options_parser = construct_ph_options_parser("testconf [options]")
      conf_options_parser.add_option("--Fraction-of-Scenarios-for-Solve",
                                     help="The fraction of scenarios that are allocated to finding a solution. Default is 0.5",
                                     action="store",
                                     dest="fraction_for_solve",
                                     type="float",
                                     default=0.5)
      conf_options_parser.add_option("--Number-of-Samples-for-Confidence-Interval",
                                     help="The number of samples of scenarios that are allocated to the confidence inteval (n_g). Default is 10",
                                     action="store",
                                     dest="n_g",
                                     type="int",
                                     default=10)
      conf_options_parser.add_option("--Confidence-Interval-Alpha",
                                     help="The alpha level for the confidence interval. Default is 0.10",
                                     action="store",
                                     dest="Conf_Alpha",
                                     type="float",
                                     default=0.10)
      conf_options_parser.add_option("--solve-hatx-with-ef-only",
                                     help="Perform hatx solve via EF rather than PH. Default is False",
                                     action="store_true",
                                     dest="solve_hatx_with_ef_only",
                                     default=False)
      (options, args) = conf_options_parser.parse_args(args=args)
   except SystemExit:
      # The parser raises SystemExit when "-h" is given (and on parse
      # errors). Catch it so that the script returns gracefully.
      return

   # NOTE: the seed is hard-coded on purpose, so that repeated runs draw the
   # same scenario permutation.
   random.seed(17)

   reference_model, reference_instance, full_scenario_tree, si = load_reference_and_scenario_models(options)

   # TODO: bail out right away if the objective is maximized. The bound
   # E_f_of_xstar - gap below presumably assumes minimization.

   print("Confidence Interval Calculations")
   print("full_scenario_tree pprint:")
   full_scenario_tree.pprint()
   scenariocount = len(full_scenario_tree._stages[-1]._tree_nodes)
   num_stages = len(full_scenario_tree._stages)
   if num_stages > 2:
      print("Confidence intervals are available only for two stage problems. Stage count= %s" % num_stages)
      sys.exit(1)

   # Random permutation of the scenario indexes. list() makes this a real
   # (shuffleable) list regardless of what range() returns.
   IndList = list(range(scenariocount))
   random.shuffle(IndList)
   print("After shuffle, IndList= " + str(IndList))

   num_scenarios_for_solution = int(options.fraction_for_solve * scenariocount)
   n_g = options.n_g
   if n_g < 2:
      # A sample variance (division by n_g - 1) is computed at the end. Fail
      # now rather than after all of the solves.
      print("ERROR: at least 2 samples are required for the confidence interval; n_g= %s" % n_g)
      sys.exit(1)
   num_scenarios_per_sample = (scenariocount - num_scenarios_for_solution) // n_g  # n in Morton's slides
   wastedscenarios = scenariocount - num_scenarios_for_solution - n_g * num_scenarios_per_sample

   print("%s Scenarios,from which  %s  are to be used to find a solution and " % (scenariocount, num_scenarios_for_solution))
   print("%s  groups of  %s  are to be used for the confidence interval." % (n_g, num_scenarios_per_sample))
   if wastedscenarios > 0:
      print("(so  %s  will not be used.)" % wastedscenarios)

   if num_scenarios_for_solution < 1 or num_scenarios_per_sample < 1:
      print("ERROR: too few scenarios to allocate at least one to the solution and to each of the %s samples" % n_g)
      sys.exit(1)

   # Create the PH object used to find the solution. ph_for_bundle's stop
   # index is inclusive, so this bundle is IndList[0 .. num_scenarios_for_solution-1].
   # It must not overlap the first sample group, which starts at
   # num_scenarios_for_solution.
   hatxph = ph_for_bundle(0, num_scenarios_for_solution - 1, reference_model, full_scenario_tree, reference_instance, si, IndList, options)

   if options.solve_hatx_with_ef_only:
      # NOTE (DLW): hat{x} comes from the extensive form of the solution
      # bundle. Its solution is snapshotted back into the scenario tree.
      hatex_ef = create_ef_instance(hatxph._scenario_tree, hatxph._instances)
      ef_results = write_and_solve_ef(hatex_ef, hatxph._instances, options)
      load_ef_solution(ef_results, hatex_ef, hatxph._instances)
      hatxph._scenario_tree.snapshotSolutionFromInstances(hatxph._instances)
      hatxph._scenario_tree.pprintSolution()
   else:
      print("SOLVING HATX VIA PH!")
      hatxph.solve()
   print("so we now have the hat{x} variables in the PHAVG structure of hatxph")

   print("now form and solve the problems for each sample")
   # To handle scenarios that are not equally likely, the expectations for
   # G^k are split per scenario. The groups themselves are assumed to be
   # equally likely, so Gbar and VarG are simply scaled by n_g and n_g-1.
   # Storing every G^k allows a two-pass variance; see
   # http://www.eecs.berkeley.edu/~mhoemmen/cs194/Tutorials/variance.pdf
   G_supk_of_hatx = []
   for k in range(1, n_g + 1):
      start = num_scenarios_for_solution + (k - 1) * num_scenarios_per_sample
      stop = start + num_scenarios_per_sample - 1   # inclusive
      print("Sample k= %s" % k)
      # NOTE: this PH object is never run; it only supplies the bundle's
      # scenario tree, instances and solver plumbing. The code could be
      # refactored.
      phforGk = ph_for_bundle(start, stop, reference_model, full_scenario_tree, reference_instance, si, IndList, options)
      efforGk = create_ef_instance(phforGk._scenario_tree, phforGk._instances)
      ef_results = write_and_solve_ef(efforGk, phforGk._instances, options)
      load_ef_solution(ef_results, efforGk, phforGk._instances)
      print("HEY JPW!!! we need to add or subtract the gap")
      # NOTE(review): this hard-codes the EF objective component name 'f'.
      E_f_of_xstar = float(ef_results.solution(0).objective['f'].value)
      print("E_f_of_xstar= %s" % E_f_of_xstar)
      # NOTE(review): assumes the solver reports an absolute, non-negative
      # gap. Confirm the sign convention.
      efforGk_gap = ef_results.solution(0).gap
      E_f_of_xstar_bound = E_f_of_xstar - efforGk_gap

      E_f_of_Hatx = _expected_cost_at_hatx(hatxph, phforGk)
      G_supk_of_hatx.append(E_f_of_Hatx - E_f_of_xstar_bound)

   Gbar, VarG = _mean_and_sample_variance(G_supk_of_hatx)
   print("Gbar= %s" % Gbar)
   print("VarG= %s" % VarG)


def _expected_cost_at_hatx(hatxph, phforGk):
   """Return the expected cost of phforGk's scenarios with the first stage fixed at hat{x}.

   Each scenario instance of phforGk has its first-stage variables fixed from
   hatxph's PHAVG values. It is then solved through phforGk's solver manager,
   and the scenario costs are weighted by the scenario probabilities.
   All solves are queued first and collected afterwards, which allows a
   parallel solver manager to run them concurrently.
   Raises RuntimeError if a solve returns no solution.
   """
   import time

   scenarios = phforGk._scenario_tree._scenarios
   # Maps each action handle to the scenario object it was queued for. Each
   # result is then weighted by the probability of ITS scenario.
   action_handle_scenario_map = {}
   for scenario in scenarios:
      instance = phforGk._instances[scenario._name]
      # do the fixing at hatx
      fix_first_stage_vars_for_instance_from_PHAVG(hatxph, instance)

      instance.preprocess()

      new_action_handle = phforGk._solver_manager.queue(instance, opt=phforGk._solver, tee=phforGk._output_solver_log)
      action_handle_scenario_map[new_action_handle] = scenario

   E_f_of_Hatx = 0.0
   num_results_so_far = 0
   while num_results_so_far < len(scenarios):
      action_handle = phforGk._solver_manager.wait_any()
      results = phforGk._solver_manager.get_results(action_handle)
      scenario = action_handle_scenario_map[action_handle]
      scenario_name = scenario._name
      instance = phforGk._instances[scenario_name]

      if phforGk._verbose is True:
         print("Sampling Results obtained for scenario=" + scenario_name)

      if len(results.solution) == 0:
         results.write(num=1)
         raise RuntimeError("Sampling Solve failed for scenario=" + scenario_name + "; no solutions generated")

      if phforGk._output_solver_results is True:
         print("Sampling Results for scenario= %s" % scenario_name)
         results.write(num=1)

      start_time = time.time()
      instance.load(results)
      end_time = time.time()
      if phforGk._output_times is True:
         print("Time loading results into instance=" + str(end_time - start_time) + " seconds")

      if phforGk._verbose is True:
         print("Successfully loaded solution for scenario=" + scenario_name)

      num_results_so_far += 1

      objval = phforGk._scenario_tree.compute_scenario_cost(instance)
      E_f_of_Hatx += objval * scenario._probability

   return E_f_of_Hatx


def _mean_and_sample_variance(values):
   """Return (mean, unbiased sample variance) of a sequence of at least two numbers."""
   count = len(values)
   mean = sum(values) / float(count)
   # Second pass over the stored values. Every value contributes, and the
   # sum is divided by count - 1.
   variance = sum((v - mean) * (v - mean) for v in values) / float(count - 1)
   return mean, variance

#==============================================   
def ph_for_bundle(BundleStart, BundleStop, reference_model, full_scenario_tree, reference_instance, scenario_tree_instance, IndList, options):
   """Create a PH object over a bundle of scenarios of the full scenario tree.

   The bundle holds the scenarios full_scenario_tree._scenarios[IndList[i]]
   for i = BundleStart .. BundleStop. BundleStop is INCLUSIVE.
   A ScenarioTree restricted to that bundle is built from reference_instance
   and scenario_tree_instance, and is then handed to create_ph_from_scratch
   together with options and reference_model.

   Returns the PH object produced by create_ph_from_scratch.
   Exits the process with status 1 if the bundle scenario tree fails
   validation.
   """
   import sys

   # IndList maps positions of the random permutation to scenario indexes.
   scenarios_to_bundle = [full_scenario_tree._scenarios[IndList[i]]._name
                          for i in range(BundleStart, BundleStop + 1)]

   print("Scenarios to bundle:" + str(scenarios_to_bundle))

   scenario_tree_for_soln = ScenarioTree(scenarioinstance=reference_instance,
                                         scenariotreeinstance=scenario_tree_instance,
                                         scenariobundlelist=scenarios_to_bundle)

   if scenario_tree_for_soln.validate() is False:
      print("***ERROR: Bundle Scenario tree is invalid****")
      # Exit with a non-zero status: this is a failure, not a normal
      # termination, and shells/callers must be able to tell the difference.
      sys.exit(1)

   print("Scenario tree for solution is valid!")

   return create_ph_from_scratch(options, reference_model, reference_instance, scenario_tree_for_soln)

#==============================================   
def fix_first_stage_vars_for_instance_from_PHAVG(ph, instance):
   """Fix the first-stage variables of `instance` at the PHAVG values held by `ph`.

   The variables come from the root stage of ph's scenario tree. The PHAVG_<name>
   parameters are read from the PH instance of the first scenario. For every
   index whose variable status in that PH instance is not VarStatus.unused,
   the matching entry of `instance` is marked fixed and given the PHAVG value.
   Integer and boolean variables get the value rounded to the nearest int.
   """
   root_stage = ph._scenario_tree._stages[0]   # root node only
   # All PH scenario instances should carry the same PHAVG values, so the
   # first scenario's instance serves as the source.
   source_name = ph._scenario_tree._scenarios[0]._name
   source_instance = ph._instances[source_name]

   for var_component, _index_template, indices in root_stage._variables:
      name = var_component.name
      domain = var_component.domain
      # TODO: take the values from the scenario tree itself rather than from
      #       PHAVG; the only difficulty is finding the root node object.
      averages = getattr(source_instance, "PHAVG_" + name)

      for idx in indices:
         if getattr(source_instance, name)[idx].status != VarStatus.unused:
            value = averages[idx]()
            if isinstance(domain, (IntegerSet, BooleanSet)):
               value = int(round(value))
            target = getattr(instance, var_component.name)[idx]
            target.fixed = True
            target.value = value
   print("DLW says: first stage vars are fixed; maybe we need to delete any constraints with only first stage vars due to precision issues")

#==============================================   
def write_and_solve_ef(master_instance, scenario_instances, options, ef_file_name="dlwef.lp"):
   """Write an extensive form to an LP file, solve it, and return the results.

   master_instance    -- the EF instance to write (e.g. from create_ef_instance)
   scenario_instances -- the scenario instances the EF was built from
   options            -- parsed options. Reads solver_type, ef_solver_options,
                         ef_mipgap, solver_manager_type and output_ef_solver_log.
   ef_file_name       -- path of the LP file to write ("~" is expanded).
                         Defaults to "dlwef.lp" in the working directory.

   Raises ValueError if the solver or the solver manager cannot be created,
   or if options.ef_mipgap is outside the unit interval.
   Returns the results obtained from the solver manager for the EF solve.
   """
   ef_file_path = os.path.expanduser(ef_file_name)
   write_ef(master_instance, scenario_instances, ef_file_path)

   print("")
   print("Solving extensive form written to file=" + ef_file_path)
   print("")

   ef_solver = SolverFactory(options.solver_type)
   if ef_solver is None:
      raise ValueError("Failed to create solver of type=" + options.solver_type + " for use in extensive form solve")
   if options.ef_solver_options:
      print("Initializing ef solver with options=" + str(options.ef_solver_options))
      # Separate the individual option strings with blanks. Joining them with
      # "" would glue e.g. "mipgap=0.01" and "threads=2" into one bogus token.
      ef_solver.set_options(" ".join(options.ef_solver_options))
   if options.ef_mipgap is not None:
      if not (0.0 <= options.ef_mipgap <= 1.0):
         raise ValueError("Value of the mipgap parameter for the EF solve must be on the unit interval; value specified=" + repr(options.ef_mipgap))
      ef_solver.mipgap = options.ef_mipgap

   ef_solver_manager = SolverManagerFactory(options.solver_manager_type)
   if ef_solver_manager is None:
      raise ValueError("Failed to create solver manager of type=" + options.solver_manager_type + " for use in extensive form solve")

   print("Queuing extensive form solve")
   ef_action_handle = ef_solver_manager.queue(ef_file_path, opt=ef_solver, warmstart=False, tee=options.output_ef_solver_log)
   print("Waiting for extensive form solve")
   ef_results = ef_solver_manager.wait_for(ef_action_handle)
   # Uncomment to dump the raw EF results:
   # print("Extensive form solve results:")
   # ef_results.write(num=1)
   return ef_results

#=================================================================================================
# Run the confidence-interval computation when this file is executed as a
# script. The __main__ guard keeps a plain "import" of this module from
# immediately parsing sys.argv and launching the solves.
if __name__ == "__main__":
   print("STARTING TO RUN")
   run()
   print("DONE")
