import json
import multiprocessing
import os
import time

from compute_perp import check_equal
def solve(predict, answer):
    """Build a symmetric equality cache for one set of predictions.

    Each prediction is compared with ``answer`` and with every prediction
    (itself included) through ``check_equal``. A result is stored under both
    ``str(a) + "<##>" + str(b)`` and ``str(b) + "<##>" + str(a)`` so it can be
    looked up with the operands in either order.

    Args:
        predict: Sequence of predictions for a single problem.
        answer: Reference answer the predictions are compared against.

    Returns:
        dict mapping ``"<a><##><b>"`` strings to whatever ``check_equal``
        returned for that pair.
    """
    separator = "<##>"
    cache_dict = {}

    def remember(left, left_text, right, right_text):
        # Compare a pair once and record the result under both key orders.
        key = left_text + separator + right_text
        rev_key = right_text + separator + left_text
        if key in cache_dict or rev_key in cache_dict:
            return
        val = check_equal(left, right)
        cache_dict[key] = val
        cache_dict[rev_key] = val

    items = list(predict)
    # Stringify each operand once instead of on every loop iteration.
    texts = [str(item) for item in items]
    answer_text = str(answer)

    # Prediction-vs-answer comparisons.
    for item, text in zip(items, texts):
        remember(item, text, answer, answer_text)

    # Prediction-vs-prediction comparisons. Only j >= i is visited: for j < i
    # the mirrored pair (j, i) was handled earlier and already stored both key
    # orders, so those iterations could never add anything.
    for i, (left, left_text) in enumerate(zip(items, texts)):
        for j in range(i, len(items)):
            remember(left, left_text, items[j], texts[j])

    return cache_dict
def cache(data, cache_path):
    """Build the equality cache for ``data`` and save it as JSON.

    Does nothing (apart from printing a notice) when ``cache_path`` already
    exists. Otherwise ``solve`` is run for every ``(predict, answer)`` pair in
    a process pool, the per-pair dictionaries are merged, and the merged
    dictionary is dumped to ``cache_path``.

    Args:
        data: Mapping with a ``"predict"`` entry and an ``"answer"`` entry;
            ``data["predict"][i]`` and ``data["answer"][i]`` are passed
            together to ``solve``.
        cache_path: Destination path of the JSON cache file.

    Returns:
        None.
    """
    if os.path.exists(cache_path):
        print(f"Cache file {cache_path} exists, skip!")
        return

    start_time = time.time()
    predicts = data["predict"]
    answers = data["answer"]
    n = len(predicts)

    # Indexing (rather than zip) keeps the IndexError when there are fewer
    # answers than predictions instead of silently truncating.
    tasks = [(predicts[i], answers[i]) for i in range(n)]
    with multiprocessing.Pool() as pool:
        results = pool.starmap(solve, tasks)

    cache_dict = {}
    for result in results:
        cache_dict.update(result)

    # Write to a sibling temp file and publish it with os.replace so that an
    # interrupted or failing dump never leaves a truncated file at cache_path,
    # which later runs would treat as a finished cache and skip.
    tmp_path = f"{cache_path}.tmp.{os.getpid()}"
    try:
        with open(tmp_path, "w") as fw:
            json.dump(cache_dict, fw)
        os.replace(tmp_path, cache_path)
    finally:
        # After a successful replace the temp file is gone; this only cleans
        # up when the dump or the rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    print(f"Cache file {cache_path} built in {time.time() - start_time:.2f}S")