[문제]
때는 2040년, 이민혁은 우주에 자신만의 왕국을 만들었다. 왕국은 N개의 행성으로 이루어져 있다. 민혁이는 이 행성을 효율적으로 지배하기 위해서 행성을 연결하는 터널을 만들려고 한다.
행성은 3차원 좌표위의 한 점으로 생각하면 된다.
두 행성 A(xA, yA, zA)와 B(xB, yB, zB)를 터널로 연결할 때 드는 비용은 min(|xA-xB|, |yA-yB|, |zA-zB|)이다.
민혁이는 터널을 총 N-1개 건설해서 모든 행성이 서로 연결되게 하려고 한다. 이때, 모든 행성을 터널로 연결하는데 필요한 최소 비용을 구하는 프로그램을 작성하시오.
[입력 조건]
- 첫째 줄에 행성의 개수 N이 주어진다. (1 ≤ N ≤ 100,000) 다음 N개 줄에는 각 행성의 x, y, z좌표가 주어진다.
- 좌표는 -109보다 크거나 같고, 109보다 작거나 같은 정수이다. 한 위치에 행성이 두 개 이상 있는 경우는 없다.
[출력 조건]
- 첫째 줄에 모든 행성을 터널로 연결하는데 필요한 최소 비용을 출력한다.
실패한 방법
import java.io.BufferedReader
import java.io.InputStreamReader
import java.util.*
import kotlin.math.abs
private lateinit var parent : IntArray
fun main(){
val br = BufferedReader(InputStreamReader(System.`in`))
var stk : StringTokenizer
val n = br.readLine().toInt()
parent = IntArray(n){ i -> i }
val starArray = Array(n){ Triple(0, 0, 0) } // a별 x좌표, a별 y좌표 a별 z좌표
val edgeList : MutableList<Triple<Int, Int, Int>> = mutableListOf() // a별, b별, a->b cost
repeat(n){ i ->
stk = StringTokenizer(br.readLine())
starArray[i] = Triple(stk.nextToken().toInt(), stk.nextToken().toInt(), stk.nextToken().toInt())
repeat(i){ j ->
edgeList.add(Triple(i, j, getDistance(starArray[i], starArray[j])))
}
}
edgeList.sortWith(compareBy{ it.third })
var totalCost = 0
edgeList.forEach{
val a = it.first
val b = it.second
val cost = it.third
if(findParent(a) != findParent(b)){
union(a, b)
totalCost += cost
}
}
print(totalCost)
}
private fun getDistance(a : Triple<Int, Int, Int>, b : Triple<Int, Int, Int>): Int{
return minOf(abs(a.first - b.first), abs(a.second - b.second), abs(a.third - b.third))
}
private fun findParent(x: Int): Int{
if(x == parent[x]) return x
parent[x] = findParent(parent[x])
return parent[x]
}
private fun union(a : Int, b: Int) {
val A = findParent(a)
val B = findParent(b)
if (A == B) return
parent[B] = A
}
행성간의 거리를 구할 수 있는 getDistance() 함수를 따로 만들어, 행성과 다른 행성간의 모두 간선을 연결하고, 간선을 기준으로 정렬한 후 행성 두 개가 루트 노드가 다른 경우 결과에 해당 간선의 거리를 더해주는 방식으로 구현하였다. 해당 풀이는 메모리 초과 오류가 발생하였고, 입력 조건을 보면 이유를 알 수 있었다. 행성의 갯수는 총 10^5 이하로 최대 간선의 갯수는 (10^5) * (10^ 5 - 1) / 2 개가 발생해 메모리 초과가 발생한다.
해결한 방법
import java.io.BufferedReader
import java.io.InputStreamReader
import java.util.*
private data class Planet(
val x : Int,
val y : Int,
val z : Int,
val i : Int
)
private lateinit var parent : IntArray
fun main(){
val br = BufferedReader(InputStreamReader(System.`in`))
var stk : StringTokenizer
val n = br.readLine().toInt()
parent = IntArray(n){ i -> i }
val starArray = Array(n){ Planet(0, 0, 0, 0) } // a별 x좌표, a별 y좌표 a별 z좌표, 인덱스
val edgeList : MutableList<Triple<Int, Int, Int>> = mutableListOf() // a별, b별, a->b cost
repeat(n){ i ->
stk = StringTokenizer(br.readLine ())
val x = stk.nextToken().toInt()
val y = stk.nextToken().toInt()
val z = stk.nextToken().toInt()
starArray[i] = Planet(x, y, z, i)
}
val xList = starArray.sortedWith(compareBy{ it.x })
val yList = starArray.sortedWith(compareBy{ it.y })
val zList = starArray.sortedWith(compareBy{ it.z })
for(i in 0 until n - 1){
edgeList.add(Triple(xList[i].i, xList[i + 1].i , xList[i + 1].x - xList[i].x))
edgeList.add(Triple(yList[i].i, yList[i + 1].i, yList[i + 1].y - yList[i].y))
edgeList.add(Triple(zList[i].i, zList[i + 1].i, zList[i + 1].z - zList[i].z))
}
edgeList.sortWith( compareBy{ it.third } )
var totalCost = 0
edgeList.forEach{
val a = it.first
val b = it.second
val cost = it.third
if(findParent(a) != findParent(b)){
union(a, b)
totalCost += cost
}
}
print(totalCost)
}
private fun findParent(x: Int): Int{
if(x == parent[x]) return x
parent[x] = findParent(parent[x])
return parent[x]
}
private fun union(a : Int, b: Int) {
val A = findParent(a)
val B = findParent(b)
if (A == B) return
parent[B] = A
}
메모리 초과를 피하기 위해서 간선의 갯수를 줄여야 했다. 행성간 거리를 도출하는 식으로부터 힌트를 얻을 수 있었다. 행성간 거리는 x좌표의 차, y좌표의 차, z좌표의 차 중 가장 작은 값이 거리가 된다. 그 말은 하나의 행성에서 연결될 수 있는 행성은 x, y, z 좌표중 하나가 가장 가깝다는 표현으로 이해할 수 있다. 행성을 정렬해도 순서가 섞이는 것을 방지하기 위해 Planet이라는 dataClass를 새로 생성했고, x, y, z, i(행성의 번호)를 함께 저장했다. 행성을 x, y, z로 정렬하고 하나의 행성과 다음 인덱스의 행성과 간선 값을 edgeList에 저장했다. edgeList를 간선의 값에 따라 정렬하고 edgeList에서 뽑은 두 행성의 루트 노드가 다른 경우 해당 간선의 값을 결과 값에 더하는 방식으로 결과를 도출했다.