#  _________________________________________________________________________
#
#  Coopr: A COmmon Optimization Python Repository
#  Copyright (c) 2008 Sandia Corporation.
#  This software is distributed under the BSD License.
#  Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
#  the U.S. Government retains certain rights in this software.
#  For more information, see the Coopr README.txt file.
#  _________________________________________________________________________

#
# Object that stores information parsed from an 
# AMPL *.mod file.
#
# NOTE: this does not store node/arc info, but it could probably be
# adapted for that purpose
# 

import re
from mapfile_parser import parse_mapfile

class AmplInfo:
    """Symbol information collected while parsing an AMPL model.

    Symbols are recorded in declaration order in ``self.items`` as
    ``(dtype, name, indices, dimen)`` tuples.  They are also recorded in a
    per-kind dictionary (``set``, ``param``, ``var``, ``min``, ``max``,
    ``con``) that maps each name to its indices.  Symbols may instead be
    declared ahead of time through SUCASA mapfile declarations (see
    ``add_mapfile_declaration``).
    """

    def __init__(self):
        self.items = []                 # (dtype, name, indices, dimen) in add() order
        self.supersets = {}             # name -> superset passed to add()
        self.dimen = {}                 # name -> dimension recorded by add()
        self.set = {}                   # per-kind tables: name -> indices
        self.param = {}
        self.var = {}
        self.min = {}
        self.max = {}
        self.con = {}
        self.exported_symbols = set()   # names to export; "*" means "all symbols"
        self.symbols = set()            # every name passed through add()
        self.mapfile_decl={}            # name -> raw SUCASA declaration text
        self.concrete={}        # Only concrete values are printed in the AMPL model

    def add(self, dtype, name, indices, superset=None, dimen=0,_concrete=True):
        """Add a symbol.

        dtype     -- AMPL declaration keyword.  "minimize"/"maximize" are
                     stored as "min"/"max", and "subject"/"s.t." as "con".
        name      -- symbol name
        indices   -- index information stored verbatim
        superset  -- optional superset recorded for the symbol
        dimen     -- dimension; only honored for sets, all other kinds use 1
        _concrete -- whether the symbol is concrete

        If the name already has a mapfile declaration, the symbol is only
        marked concrete and nothing else is recorded.
        """
        if name in self.mapfile_decl:
            # The mapfile declaration already registered this symbol; the
            # AMPL declaration just confirms that it is concrete.
            self.concrete[name]=True
            return
        if dtype != "set":
            dimen = 1
        if dtype == "minimize":
            dtype = "min"
        elif dtype == "maximize":
            dtype = "max"
        elif dtype == "subject" or dtype == "s.t.":
            dtype = "con"
        self.supersets[name] = superset
        self.dimen[name] = dimen
        self.concrete[name] = _concrete
        self.items.append( (dtype,name,indices,dimen) )
        self.symbols.add(name)
        # Record the indices in the per-kind table; unknown kinds are only
        # kept in self.items.
        table = {
            "set": self.set,
            "param": self.param,
            "var": self.var,
            "min": self.min,
            "max": self.max,
            "con": self.con,
        }.get(dtype)
        if table is not None:
            table[name] = indices

    def add_mapfile_declaration(self,symbol,text):
        """Parse mapfile declarations that are included in an AMPL model.

        symbol -- declared symbol, optionally followed by "[...]" indices
        text   -- declaration text; its first whitespace-separated word is
                  used as the symbol's dtype

        Raises IOError if the symbol was already added, i.e. the AMPL
        declaration came first.
        """
        name = symbol.split('[')[0]
        if name in self.symbols:            #pragma:nocover
            raise IOError("ERROR: sucasa declaration for variable "+name+" must occur before the AMPL declaration")
        # Split on any whitespace (including newlines), so that a multi-line
        # declaration still yields a clean keyword such as "var".
        dtype = re.split(r'\s+', text.strip())[0]
        # add() must run before mapfile_decl is set; otherwise it would take
        # the "already declared" early-return path.
        self.add(dtype,name,None,_concrete=False)
        self.mapfile_decl[name] = text

    def __str__(self):
        """Generate a string that represents the data in the AMPL model"""
        lines = []
        for dtype, name, _indices, _dimen in self.items:
            # Look up the indices in the per-kind table named after dtype.
            indices = getattr(self, dtype)[name]
            lines.append(dtype+" "+name+" "+str(indices)+" "+str(self.supersets[name]))
        return "\n".join(lines)

    def __iter__(self):
        """ Enable iteration through the list of symbols """
        return self.items.__iter__()

    def update_exports(self):
        """Validate exported_symbols and expand a "*" entry.

        Raises IOError (after printing the offending names) if any exported
        name other than "*" was never added.  If "*" is present,
        exported_symbols is replaced by the names of all added items.
        """
        #
        # First, we verify that the exported symbols are valid
        #
        invalid = [item for item in self.exported_symbols
                   if item not in self.symbols and item != "*"]
        if invalid:
            print("    ERROR: the following exported symbols are not valid")
            for item in invalid:
                print("      "+item)
            raise IOError("Invalid exported symbols")
        #
        # Next, we expand the exported symbols if "*" is present
        #
        if "*" in self.exported_symbols:
            self.exported_symbols = set(item[1] for item in self.items)

    def initialize(self, symbolic_info, quiet=False):
        """ Initialize a MILPSymbInfo object with data parsed from AMPL.

        Variables, objectives, constraints and exported symbols are pushed
        into symbolic_info.  A symbol with a mapfile declaration is handed to
        parse_mapfile.  Any other symbol is registered via
        symbolic_info.add_symbol, with objectives ("min"/"max") registered
        as "con".
        """
        self.update_exports()
        export_all = "*" in self.exported_symbols
        for dtype, name, indices, dimen in self.items:
            is_model_component = dtype in ("var","min","max","con")
            is_exported = export_all or name in self.exported_symbols
            if (is_model_component or is_exported) and name in self.mapfile_decl:
                parse_mapfile(symbolic_info, data=self.mapfile_decl[name]+";",add_temp_sets=True, debug=symbolic_info.verbose)
            elif is_model_component:
                stype = "con" if dtype in ("min","max") else dtype
                symbolic_info.add_symbol(stype,name,index=indices,tmpsets=True, superset="integers",dimen=dimen,quiet=quiet)
            elif is_exported:
                # Reaching this branch implies no mapfile declaration (the
                # first branch would have matched).
                symbolic_info.add_symbol(dtype,name,index=indices,tmpsets=True, superset=self.supersets[name],dimen=dimen,quiet=quiet)

    def check(self, symbolic_info):
        """Validate the symbolic information.

        Reports two kinds of problem:
        - non-set symbols in symbolic_info that are missing from this AMPL
          parse info;
        - exported AMPL symbols that are missing from symbolic_info.

        Each problem is printed.  Returns True if none were found, else
        False.
        """
        flag=True
        for name in symbolic_info.symbol:
            if symbolic_info.stype[name] == "set":
                continue
            if name not in self.symbols:
                print("      ERROR: failed to find symbol '"+name+"' in AMPL parse info ")
                flag=False
        for table in (self.set, self.param, self.var, self.min, self.max, self.con):
            for name in table:
                if name in self.exported_symbols and name not in symbolic_info.symbol:
                    print("      ERROR: failed to find symbol '"+name+"' in mapfile")
                    flag=False
        return flag

