spadyのメモ帳

技術ブログにしたいけどどうなることやら。まだ素人

焼き鈍しのビジュアライズ

さて、SuperConで盛大に焼き鈍しの実装に失敗したので勉強しようと思います。

en.wikipedia.org

Wikiみると良さげなgifがありますね。これを再現してみることにします。

まず、このグラフはなんぞやってことなんですが、調べてもわからなかったのでそれっぽいグラフを設定します。

今回は地形の断面図の標高で代用します。国土地理院地理院地図でカラチ(パキスタン)からイスタンブール(トルコ)の断面図を用意しました。この最高点を探すことにします。 f:id:spady:20210828121930p:plain

実装で参考にした記事はこちらです。 gasin.hatenadiary.jp

実装コードはこちら

#include <bits/stdc++.h>
using namespace std;
const int INF = 1 << 30;

//カラチからイスタンブールまでの断面図標高 by 国土地理院
int elevation[] = {
    30,   35,   95,   26,   3,    81,   87,   144,  217,  234,  101,  956,
    268,  447,  500,  662,  574,  895,  1071, 979,  1016, 1398, 1107, 1043,
    998,  920,  962,  1002, 940,  1351, 768,  1051, 1000, 905,  819,  780,
    759,  723,  860,  959,  1052, 1251, 1309, 1105, 1368, 1415, 1841, 1625,
    1543, 1556, 2037, 2228, 1939, 2157, 2087, 1796, 1591, 1526, 1474, 1459,
    1359, 1567, 810,  518,  475,  500,  524,  508,  402,  298,  272,  269,
    303,  320,  352,  289,  336,  323,  313,  310,  301,  383,  511,  1388,
    1159, 1228, 2458, 1750, 1413, 1583, 2066, 2445, 2663, 1976, 2040, 1890,
    1920, 1956, 1789, 1433, 1200, 1034, 929,  925,  1006, 1358, 1808, 2037,
    1861, 1533, 1440, 1070, 968,  963,  965,  986,  1042, 1183, 1312, 1486,
    1801, 1973, 2060, 2349, 2029, 1449, 1313, 1324, 1296, 1499, 1412, 2006,
    2061, 1867, 2429, 1614, 2828, 2329, 2342, 1676, 1383, 1421, 1948, 1785,
    1863, 1884, 1861, 1824, 1870, 2036, 1741, 1715, 1617, 1719, 1847, 1968,
    1697, 2012, 2292, 2049, 2037, 1801, 1773, 1741, 1968, 2027, 1997, 1911,
    2067, 2116, 2308, 2421, 1817, 1717, 1957, 2040, 2146, 1650, 1974, 1197,
    1469, 1922, 2286, 1406, 1360, 820,  1450, 855,  670,  673,  1094, 1344,
    1613, 1721, 1185, 881,  982,  968,  913,  1484, 868,  1343, 1991, 900,
    1275, 1247, 908,  848,  928,  660,  909,  710,  656,  667,  575,  695,
    734,  704,  743,  889,  948,  875,  1016, 1343, 1440, 1475, 1285, 1372,
    1394, 1657, 1179, 707,  796,  1198, 1157, 1254, 1511, 1410, 1622, 1505,
    1571, 1764, 1741, 1873, 1535, 1379, 1252, 1179, 1663, 1702, 1641, 1185,
    1094, 1155, 1122, 1061, 1193, 1122, 922,  953,  931,  863,  761,  1000,
    1120, 1144, 884,  1383, 1166, 1022, 1041, 963,  1249, 1075, 1390, 1543,
    1159, 1503, 1438, 1154, 948,  1226, 1311, 1648, 426,  33,   30,   36,
    315,  196,  210,  368,  277,  107,  150,  76};

struct Timer {
   public:
    Timer() { restart(); }

    void restart() { m_start = std::chrono::steady_clock::now(); }

    auto elapsed() {
        std::chrono::steady_clock::time_point en =
            std::chrono::steady_clock::now();
        auto dur = en - m_start;
        return std::chrono::duration_cast<std::chrono::milliseconds>(dur)
            .count();
    }

   private:
    std::chrono::_V2::steady_clock::time_point m_start;
};

int main() {
    Timer tim;
    std::random_device seed_gen;
    std::default_random_engine engine(seed_gen());

    // 0以上9以下の値を等確率で発生させる
    std::uniform_int_distribution<> dist(0, 295);

    const int TIME_LIMIT = 2 * 1000;
    double start_temp = 1000, end_temp = 10;  // 適当な値を入れる(後述)
    long long  start_time = tim.elapsed();        // 開始時刻

    int pre_score = elevation[dist(engine)];
    // mountain
    while (tim.elapsed() < TIME_LIMIT) {
        int d = dist(engine);
        int new_score = elevation[d];
        long long now_time = tim.elapsed();
        // 温度関数
        double temp = start_temp + (end_temp - start_temp) *
                                       (long long)(now_time - start_time) / TIME_LIMIT;
        // 遷移確率関数(最大化の場合)
        double prob = exp((new_score - pre_score) / temp);

        if (prob > (rand()%INF)/(double)INF) {
            pre_score = new_score;
            cerr << d << "\n";
            cout << "Score : " << pre_score << "\n";
        }
    }
    cout << "Final Score : " << pre_score << "\n";
}

これをmatplotlibでビジュアライズしました。 コードはこちらです。

from matplotlib import pyplot as plt
from matplotlib import animation as animation

fig = plt.figure()
ims = []

x = list(range(296))

x1 = [30, 35, 95, 26, 3, 81, 87, 144, 217, 234, 101, 956, 268, 447, 500, 662, 574, 895, 1071, 979, 1016, 1398, 1107, 1043, 998,
      920, 962, 1002, 940, 1351, 768, 1051, 1000, 905, 819, 780, 759, 723, 860, 959, 1052, 1251, 1309, 1105, 1368, 1415, 1841, 
      1625, 1543, 1556, 2037, 2228, 1939, 2157, 2087, 1796, 1591, 1526, 1474, 1459, 1359, 1567, 810, 518, 475, 500, 524, 508, 
      402, 298, 272, 269, 303, 320, 352, 289, 336, 323, 313, 310, 301, 383, 511, 1388, 1159, 1228, 2458, 1750, 1413, 1583, 2066,
      2445, 2663, 1976, 2040, 1890, 1920, 1956, 1789, 1433, 1200, 1034, 929, 925, 1006, 1358, 1808, 2037, 1861, 1533, 1440, 1070, 
      968, 963, 965, 986, 1042, 1183, 1312, 1486, 1801, 1973, 2060, 2349, 2029, 1449, 1313, 1324, 1296, 1499, 1412, 2006, 2061, 
      1867, 2429, 1614, 2828, 2329, 2342, 1676, 1383, 1421, 1948, 1785, 1863, 1884, 1861, 1824, 1870, 2036, 1741, 1715, 1617, 1719,
      1847, 1968, 1697, 2012, 2292, 2049, 2037, 1801, 1773, 1741, 1968, 2027, 1997, 1911, 2067, 2116, 2308, 2421, 1817, 1717, 1957,
      2040, 2146, 1650, 1974, 1197, 1469, 1922, 2286, 1406, 1360, 820, 1450, 855, 670, 673, 1094, 1344, 1613, 1721, 1185, 881, 982, 
      968, 913, 1484, 868, 1343, 1991, 900, 1275, 1247, 908, 848, 928, 660, 909, 710, 656, 667, 575, 695, 734, 704, 743, 889, 948, 
      875, 1016, 1343, 1440, 1475, 1285, 1372, 1394, 1657, 1179, 707, 796, 1198, 1157, 1254, 1511, 1410, 1622, 1505, 1571, 1764, 
      1741, 1873, 1535, 1379, 1252, 1179, 1663, 1702, 1641, 1185, 1094, 1155, 1122, 1061, 1193, 1122, 922, 953, 931, 863, 761, 1000, 
      1120, 1144, 884, 1383, 1166, 1022, 1041, 963, 1249, 1075, 1390, 1543, 1159, 1503, 1438, 1154, 948, 1226, 1311, 1648, 426, 33, 
      30, 36, 315, 196, 210, 368, 277, 107, 150, 76]

m = []



cnt = 0
fig, ax = plt.subplots(figsize=(15,7),dpi=50)
while True:
    if cnt > 0:
        plt.cla()
    cnt += 1
    try:
        input_ = input()
    except EOFError:
        break
    im = plt.plot(x,x1,color="gray")
    im2 = plt.plot([int(input_),int(input_)],[0,3000],color="red")
    text = ax.text(0,3000,cnt,size=15,color="green")
    text2 = ax.text(20,3000,x1[int(input_)],size=15,color="green")
    ims.append(im +im2 + [text] + [text2])
    plt.savefig("./img/{}".format(cnt))

TLを2秒にしたんですがそれでも6000枚の画像が出力されました。

ビジュアライズした結果は下の動画にしました。 youtu.be

用意したデータの最高値が2828であるので、焼き鈍しで確かに最高値が得られています。

ちなみに、今回のデータでは量が少なすぎたのか山登り法でも最高値を見つけられています。 ただ、ビジュアライズが目的だったので目的は達成できました。