diff --git a/include/codecrypt.h b/include/codecrypt.h index d79626e..c26bd0c 100644 --- a/include/codecrypt.h +++ b/include/codecrypt.h @@ -126,6 +126,7 @@ public: uint add (uint, uint); uint mult (uint, uint); uint exp (uint, int); + uint exp (int); uint inv (uint); uint sq_root (uint); }; diff --git a/lib/decoding.cpp b/lib/decoding.cpp index 6ac5359..3d70bb5 100644 --- a/lib/decoding.cpp +++ b/lib/decoding.cpp @@ -65,3 +65,76 @@ bool evaluate_error_locator_dumb (polynomial&a, bvector&ev, gf2m&fld) return true; } + +/* + * berlekamp trace algorithm - we puncture roots of incoming polynomial into + * the vector of size fld.n + * + * Inspired by implementation from HyMES. + */ + +#include + +bool evaluate_error_locator_trace (polynomial&sigma, bvector&ev, gf2m&fld) +{ + ev.clear(); + ev.resize (fld.n, 0); + + std::vector trace_aux, trace; //trace cache + trace_aux.resize (fld.m); + trace.resize (fld.m); + + trace_aux[0] = polynomial(); + trace_aux[0].resize (2, 0); + trace_aux[0][1] = 1; //trace_aux[0] = x + trace[0] = trace_aux[0]; //trace[0] = x + + for (uint i = 1; i < fld.m; ++i) { + trace_aux[i] = trace_aux[i-1]; + trace_aux[i].square (fld); + trace_aux[i].mod (sigma, fld); + trace[0].add (trace_aux[i], fld); + } + + std::set > stk; //"stack" + + stk.insert (make_pair (0, sigma) ); + + while (!stk.empty() ) { + + uint i = stk.begin()->first; + polynomial cur = stk.begin()->second; + + stk.erase (stk.begin() ); + + int deg = cur.degree(); + + if (deg <= 0) continue; + if (deg == 1) { //found a linear factor + ev[fld.mult (cur[0], fld.inv (cur[1]) ) ] = 1; + continue; + } + + if (i >= fld.m) return false; + + if (trace[i].zero() ) { + //compute the trace if it isn't cached + uint a = fld.exp (i); + for (uint j = 0; j < fld.m; ++j) { + trace[i].add_mult (trace_aux[j], a, fld); + a = fld.mult (a, a); + } + } + + polynomial t; + t = cur.gcd (trace[i], fld); + polynomial q, r; + cur.divmod (t, q, r, fld); + + stk.insert (make_pair (i + 1, t) ); + stk.insert (make_pair (i + 1, q) ); + } + + return true; +} + diff --git a/lib/gf2m.cpp b/lib/gf2m.cpp index 222f396..03645df 100644 --- a/lib/gf2m.cpp +++ b/lib/gf2m.cpp @@ -139,6 +139,12 @@ uint gf2m::exp (uint a, int k) return r; } +uint gf2m::exp (int k) +{ + //return x^k + return exp (1 << 1, k); +} + uint gf2m::inv (uint a) { if (!a) return 0; diff --git a/lib/mce.cpp b/lib/mce.cpp index 926df3b..ead7c3a 100644 --- a/lib/mce.cpp +++ b/lib/mce.cpp @@ -83,7 +83,7 @@ int privkey::decrypt (const bvector&in, bvector&out) compute_error_locator (syndrome, fld, g, sqInv, loc); bvector ev; - if (!evaluate_error_locator_dumb (loc, ev, fld) ) + if (!evaluate_error_locator_trace (loc, ev, fld) ) return 1; //if decoding somehow failed, fail as well. // check the error vector, it should have exactly t == deg (g) errors @@ -151,7 +151,7 @@ int privkey::sign (const bvector&in, bvector&out, uint delta, uint attempts, prn compute_error_locator (synd, fld, g, sqInv, loc); - if (evaluate_error_locator_dumb (loc, e2, fld) ) { + if (evaluate_error_locator_trace (loc, e2, fld) ) { //create the decodable message p.add (e); diff --git a/lib/nd.cpp b/lib/nd.cpp index 0afdc8d..45ac4fc 100644 --- a/lib/nd.cpp +++ b/lib/nd.cpp @@ -62,7 +62,7 @@ int privkey::decrypt (const bvector&in, bvector&out) compute_error_locator (unsc, fld, g, sqInv, loc); bvector ev; - if (!evaluate_error_locator_dumb (loc, ev, fld) ) + if (!evaluate_error_locator_trace (loc, ev, fld) ) return 1; if ( (int) ev.hamming_weight() != g.degree() ) @@ -95,7 +95,7 @@ int privkey::sign (const bvector&in, bvector&out, uint delta, uint attempts, prn compute_error_locator (synd_unsc, fld, g, sqInv, loc); - if (evaluate_error_locator_dumb (loc, e, fld) ) { + if (evaluate_error_locator_trace (loc, e, fld) ) { Pinv.permute (e, out); return 0; diff --git a/lib/polynomial.cpp b/lib/polynomial.cpp index 433bbcf..c323ca1 100644 --- a/lib/polynomial.cpp +++ b/lib/polynomial.cpp @@ -112,6 +112,7 @@ bool polynomial::is_irreducible (gf2m&fld) const xmodf.mod (*this, fld); //mod f uint d = degree(); + if (d < 0) return false; for (uint i = 1; i <= (d / 2); ++i) { for (uint j = 0; j < fld.m; ++j) { t = xi;