やってみたらなんとかなる

プログラミングをする上で調べたこととかやったこととか

[Python] 座標圧縮ってなんだ! -AtCoder学習日記 #1

概要

今回からAtCoderのコンテストに参加して解けなかった問題の復習を書いておこうと思います。
AtCoderは基本Pythonで参加しています。 今回はABC213-Cです。

ACできないコード

import numpy
from sys import stdin

_, _, n = map(int, input().split())
low = numpy.empty(n, dtype=int)
col = numpy.empty(n, dtype=int)
for i in range(n):
  a, b = map(int, stdin.readline().split())
  low[i] = a
  col[i] = b

for i in range(n):
  print(numpy.count_nonzero(low <= low[i]), numpy.count_nonzero(col <= col[i]))

基本TLEで一部WAでした。

解説の解説

解説によると今回の問題は「座標圧縮」の問題とのことです。

そもそも「座標圧縮」とは?

座標圧縮はいらない部分を切り落とすことです。
主に情報のない部分が多いときに探索等を行う場合に使用するようです。
無駄な部分を無くすことで探索を高速化するんですね。TLEばかりの僕にピッタリダァ

座標圧縮のアルゴリズム

ある数列Aを座標圧縮することを考えます。

  1. Aをソートした数列をBとする
  2. Bの重複した値を消す
  3. Aの全要素についてBの何番目かを調べる

これで座標圧縮できます。トッテモカンタン!




このアルゴリズムを愚直に使ってコードを書いてみます。

import numpy
from sys import stdin

_, _, n = map(int, input().split())
low = numpy.empty(n, dtype=int)
col = numpy.empty(n, dtype=int)
for i in range(n):
  a, b = map(int, stdin.readline().split())
  low[i] = a
  col[i] = b
  
low_sorted = numpy.sort(numpy.unique(low))
col_sorted = numpy.sort(numpy.unique(col))

for i in range(n):
  print(numpy.where(low_sorted==low[i])[0][0]+1, numpy.where(col_sorted==col[i])[0][0]+1)

これでTLEだけになりました。ACではないです。
最後のlow_sorted、col_sortedの探索に時間がかかってると思われます。

模範回答コード

H,W,N=map(int,input().split())
X,Y=[],[]
for i in range(N):
  x,y=map(int,input().split())
  X.append(x)
  Y.append(y)

Xdict = {x:i+1 for i,x in enumerate(sorted(list(set(X))))}
Ydict = {y:i+1 for i,y in enumerate(sorted(list(set(Y))))}

for i in range(N):
  print(Xdict[X[i]], Ydict[Y[i]])

解説にある模範回答コードです。ソート後の数列の探索を辞書を使うことでほぼ無くすようにしています。
そんな解決方法があったのか…。まだまだ精進が足りないようです。

まとめ

  • 次元圧縮は無駄な情報を削ぎ落とすこと
  • インデックスの探索には辞書が使えることも