improve GetVarValueMap

This commit is contained in:
bollhals
2022-01-25 01:16:38 +01:00
parent ab519a8339
commit fd67eabfdd
2 changed files with 42 additions and 32 deletions

View File

@@ -26,7 +26,8 @@ public class CpModel
{
model_ = new CpModelProto();
constant_map_ = new Dictionary<long, int>();
terms_ = new Queue<Term>();
var_value_map_ = new Dictionary<int, long>(10);
terms_ = new Queue<Term>(10);
}
// Getters.
@@ -121,17 +122,15 @@ public class CpModel
private long FillLinearConstraint(LinearExpr expr, out LinearConstraintProto linear)
{
linear = new LinearConstraintProto();
Dictionary<IntVar, long> dict = new Dictionary<IntVar, long>();
long constant = LinearExpr.GetVarValueMap(expr, 1L, dict, terms_);
var dict = var_value_map_;
dict.Clear();
long constant = LinearExpr.GetVarValueMap(expr, dict, terms_);
var count = dict.Count;
linear = new LinearConstraintProto();
linear.Vars.Capacity = count;
linear.Vars.AddRange(dict.Keys);
linear.Coeffs.Capacity = count;
foreach (KeyValuePair<IntVar, long> term in dict)
{
linear.Vars.Add(term.Key.Index);
linear.Coeffs.Add(term.Value);
}
linear.Coeffs.AddRange(dict.Values);
return constant;
}
/**
@@ -1099,27 +1098,28 @@ public class CpModel
}
else
{
Dictionary<IntVar, long> dict = new Dictionary<IntVar, long>();
long constant = LinearExpr.GetVarValueMap(obj, 1L, dict, terms_);
var dict = var_value_map_;
dict.Clear();
long constant = LinearExpr.GetVarValueMap(obj, dict, terms_);
var dictCount = dict.Count;
objective.Vars.Capacity = dictCount;
objective.Vars.AddRange(dict.Keys);
objective.Coeffs.Capacity = dictCount;
if (minimize)
{
objective.Coeffs.AddRange(dict.Values);
objective.ScalingFactor = 1L;
objective.Offset = constant;
}
else
{
foreach (var coeff in dict.Values)
{
objective.Coeffs.Add(-coeff);
}
objective.ScalingFactor = -1L;
objective.Offset = -constant;
}
var dictCount = dict.Count;
objective.Vars.Capacity = dictCount;
objective.Coeffs.Capacity = dictCount;
foreach (KeyValuePair<IntVar, long> it in dict)
{
objective.Vars.Add(it.Key.Index);
objective.Coeffs.Add(minimize ? it.Value : -it.Value);
}
}
model_.Objective = objective;
}
@@ -1191,25 +1191,36 @@ public class CpModel
internal LinearExpressionProto GetLinearExpressionProto(LinearExpr expr, bool negate = false)
{
Dictionary<IntVar, long> dict = new Dictionary<IntVar, long>();
long constant = LinearExpr.GetVarValueMap(expr, 1L, dict, terms_);
var dict = var_value_map_;
dict.Clear();
long constant = LinearExpr.GetVarValueMap(expr, dict, terms_);
long mult = negate ? -1 : 1;
LinearExpressionProto linear = new LinearExpressionProto();
var dictCount = dict.Count;
linear.Vars.Capacity = dictCount;
linear.Vars.AddRange(dict.Keys);
linear.Coeffs.Capacity = dictCount;
foreach (KeyValuePair<IntVar, long> term in dict)
if (!negate)
{
linear.Vars.Add(term.Key.Index);
linear.Coeffs.Add(term.Value * mult);
linear.Coeffs.AddRange(dict.Values);
linear.Offset = constant;
}
linear.Offset = constant * mult;
else
{
foreach (var coeff in dict.Values)
{
linear.Coeffs.Add(-coeff);
}
linear.Offset = -constant;
}
return linear;
}
private CpModelProto model_;
private Dictionary<long, int> constant_map_;
private Dictionary<int, long> var_value_map_;
private BoolVar true_literal_;
private Queue<Term> terms_;
}

View File

@@ -35,7 +35,7 @@ public interface ILiteral
internal static class HelperExtensions
{
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static void AddOrIncrement(this Dictionary<IntVar, long> dict, IntVar key, long increment)
public static void AddOrIncrement(this Dictionary<int, long> dict, int key, long increment)
{
#if NET6_0_OR_GREATER
System.Runtime.InteropServices.CollectionsMarshal.GetValueRefOrAddDefault(dict, key, out _) += increment;
@@ -360,11 +360,10 @@ public class LinearExpr
}
}
internal static long GetVarValueMap(LinearExpr e, long initial_coeff, Dictionary<IntVar, long> dict,
Queue<Term> terms)
internal static long GetVarValueMap(LinearExpr e, Dictionary<int, long> dict, Queue<Term> terms)
{
long constant = 0;
long coefficient = initial_coeff;
long coefficient = 1;
LinearExpr expr = e;
terms.Clear();
@@ -390,10 +389,10 @@ public class LinearExpr
}
break;
case IntVar intVar:
dict.AddOrIncrement(intVar, coefficient);
dict.AddOrIncrement(intVar.GetIndex(), coefficient);
break;
case NotBoolVar notBoolVar:
dict.AddOrIncrement((IntVar)notBoolVar.Not(), -coefficient);
dict.AddOrIncrement(notBoolVar.Not().GetIndex(), -coefficient);
constant += coefficient;
break;
default: