이번 글에서는 itertools.groupby 메서드의 함정에 대해 소개하고자 한다. python 에서 itertools를 사용한다면 이 글을 자세히 보는것을 추천한다.
결론부터 말하자면 itertools.groupby를 쓸 땐, 정렬된 컬렉션을 사용해야 한다. 그렇지 않으면 집계 결과가 잘못될 수 있다. itertools.groupby는 흔히 아는 sql의 groupby처럼 동작하지 않기 때문이다.
배경
itertools.groupby()를 사용하는 로직에서 groupby만 하면 DB에서 가져온 값 일부가 사라지는 문제가 있었다. 원인을 파악하며 함정이 있는것을 알게 되었다.
사전 지식
- Itertools
- 파이썬 라이브러리로, 효율적인 알고리즘을 사용해 컬렉션 데이터를 원하는 형태의 iterator를 만들어 주는 역할을 한다.
상황
아래처럼 sample_data가 주어졌을 때, groupby를 하면 어떤 결과가 나올까?
import itertools
sample_data = ["A", "A", "B", "C", "A"]
aggr_data = {k: len(list(v)) for k, v in itertools.groupby(sample_data, key=lambda x: x)}
print(aggr_data)
결론부터 말하면, 아래와 같은 결과가 나오게 된다.
{'A': 1, 'B': 1, 'C': 1}
이상하다. 흔히 알고있는 groupby의 동작(sql의 groupby)이라면 아래와 같은 결과가 나와야 한다.
{'A': 3, 'B': 1, 'C': 1}
즉 gropuby를 사용하면서 A의 값이 제대로 집계되지 않은 것이다.
원인 분석
공식 문서 조사
itertools의 공식 문서에 groupby()에는 다음과 같은 내용이 있다.
The operation of groupby() is similar to the uniq filter in Unix. It generates a break or new group every time the value of the key function changes (which is why it is usually necessary to have sorted the data using the same key function). That behavior differs from SQL’s GROUP BY which aggregates common elements regardless of their input order.
간단하게 내용을 요약해보면 다음과 같다.
- sql의 group by와 다르게 동작한다.
- key가 바뀔때마다 key에 대한 집계를 처음부터 새로 시작한다.
문제 상기
현재 겪고 있는 문제를 다시 상기시켜보자.
- ["A", "A", "B", "C", "A"] 컬렉션이 주어졌을 때, itertools로 groupby를 하면 기대한대로값이 이 나오지 않는 문제가 있다.
- 기대한 결과: {'A': 3, 'B': 1, 'C': 1}
- 실제 결과: ({'A': 1, 'B': 1, 'C': 1})
실험
주어진 컬랙션은 정렬되지 않은 상태이다. itertools.groupby는 key 가 바뀔때마다 해당 그룹의 집계를 처음부터 다시한다.
It generates a break or new group every time the value of the key function changes (which is why it is usually necessary to have sorted the data using the same key function).
그렇다면 정렬된 값을 주어졌을때는 어떻게 될까? 기대한 값을 반환할 것이라는 가설을 세우고 검증해보겠다.
import itertools
sample_data = ["A", "A", "A", "B", "C"]
aggr_data = {k: len(list(v)) for k, v in itertools.groupby(sample_data, key=lambda x: x)}
print(aggr_data)
상기 코드의 실행 결과는 다음과 같다. 정렬된 컬렉션을 groupby하면 기대한대로 결과가 나오고 있음이 검증되었다.
{'A': 3, 'B': 1, 'C': 1}
결론
정렬된 컬렉션을 제공했을 때 itertools.groupby() 값을 누락없이 집계함을 확인할 수 있었다. 반대로 정렬되지 않은 값은, 집계가 잘못 되는 문제가 있었다. 따라서 itertools.groupby를 사용할땐 정렬된 컬렉션을 사용해야 한다.
itertools.groupby()의 대안
itertools.groupby()를 꼭 사용할 필요가 없다면, defaultdict를 활용해 groupby를 구현하자. 가장 성능이 좋은 방법이다.
from collections import defaultdict
sample_data = ["A", "A", "B", "C", "A"]
aggr_data = defaultdict(str)
for key in sample_data.items():
aggr_data[key] = value
print(aggr_data)
defaultdict를 사용했을 때 성능이 얼마나 좋은걸까? itertools, defaultdict, pure python dict 3가지 로직의 실행 속도를 비교해보았다.
input 데이터는 아래처럼 A, B, C, D, E 알파벳들을 백만개씩 나열한 컬렉션이다. python3.10에서 이 데이터를 groupby 하는데 걸린 시간을 측정해 보았다.
values = ["A", "B", "C", "D", "E"]
sample_data = [random.choice(values) for _ in range(1000000)]
실행 결과는 다음과 같다. defaultdict가 제일 빠른 성능을 보여주고 python pure dict, itertools 가 그 뒤를 잇는다.
| 횟수 | python pure dict | defaultdict | itertools.groupby() |
| 1 | 51.35 | 41.82 | 119.90 |
| 2 | 51.02 | 41.68 | 116.09 |
| 3 | 53.89 | 44.19 | 118.97 |
| 4 | 61.94 | 51.73 | 116.92 |
| 5 | 57.28 | 47.96 | 117.18 |
| 평균 | 56.03 ms | 45.47 ms | 117.81 ms |
실험에 사용한 코드는 다음과 같다.
import itertools
import random
from collections import defaultdict
import time
values = ["A", "B", "C", "D", "E"]
sample_data = [random.choice(values) for _ in range(1000000)]
def measure_execution_time(method):
def wrapper(*args, **kwargs):
start_time = time.time()
result = method(*args, **kwargs)
end_time = time.time()
print(f"{(end_time - start_time) * 1000:.3f} ms")
return result
return wrapper
@measure_execution_time
def groupby_pure_python_dict(values):
aggr_data = {}
for key in values:
aggr_data[key] = aggr_data.get(key, 0) + 1
return aggr_data
@measure_execution_time
def groupby_defaultdict(values):
aggr_data = defaultdict(int)
for key in values:
aggr_data[key] +=1
return aggr_data
@measure_execution_time
def groupby_itertools(values):
values = sorted(values)
return {k: len(list(v)) for k, v in itertools.groupby(values, key=lambda x: x)}
_ = groupby_pure_python_dict(sample_data)
_ = groupby_defaultdict(sample_data)
_ = groupby_itertools(sample_data)