https://www.acmicpc.net/problem/1759
https://github.com/stellaluminary/Baekjoon
step을 밟아나가며 중복되지 않도록 visit list를 활용하고 사전식으로 정렬된 표현을 위해 ord를 이용한 비교를 조건에 내세운다.
이후 step이 원하는 길이에 도달했을 때 모음이 1이상이고 자음이 2이상인지를 체크한 후 맞다면 출력 그렇지 않다면 그냥 return함으로써 백트래킹을 완성한다.
def dfs(step):
global res
if step == l:
mo_cnt = 0
ja_cnt = 0
for i in res:
if i in mo:
mo_cnt += 1
else:
ja_cnt += 1
if mo_cnt >= 1 and ja_cnt >=2:
print(res)
return
for i in range(step, n):
if not visit[i] and (len(res) == 0 or ord(res[-1]) < ord(c[i])):
visit[i] = 1
res += c[i]
dfs(step+1)
visit[i] = 0
res = res[:-1]
import sys
input = sys.stdin.readline
l, n = map(int, input().split())
c = sorted(list(input().split()))
visit = [0] * n
res = ''
mo = ['a', 'e', 'i', 'o', 'u']
dfs(0)
방법 1 대비 조금 더 간소화한 코드이다.
특히 사전 순서대로 선행된다는 조건을 맞추기 위해 dfs parameter로 idx를 넣고 step을 더 깊이 들어갈 때
dfs (step+1, i+1)을 사용한다.
주의) idx+1이 아닌 i+1이다.
이렇게 해야만 앞에 있는 것만 본다.
이를 통한 visit 방문 처리도 불필요 → 삭제
def dfs(step, idx):
global res
if step == l:
aeiou = 0
for i in res:
if i in 'aeiou':
aeiou += 1
if aeiou >= 1 and len(res) - aeiou >=2:
print(res)
return
for i in range(idx, n):
res += c[i]
dfs(step+1, i+1)
res = res[:-1]
import sys
input = sys.stdin.readline
l, n = map(int, input().split())
c = sorted(list(input().split()))
res = ''
dfs(0, 0)
방법 2와 달리 깊이를 깊게하며 트리와 같이 넣을 것인지 넣지 않을 것인지에 대한 2개의 dfs로 step을 밟아나가는 방법이다.
종료 조건은 2가지로 깊이가 최대 길이일 때와 출력하고자 하는 total의 길이와 원하는 길이 l과 같아졌을 때이다.
def dfs(step, total):
if len(total) == l:
aeiou = 0
for i in total:
if i in 'aeiou':
aeiou += 1
if aeiou >= 1 and len(total) - aeiou >=2:
print(total)
return
if step == n:
return
dfs(step + 1, total+c[step])
dfs(step + 1, total)
import sys
input = sys.stdin.readline
l, n = map(int, input().split())
c = sorted(list(input().split()))
dfs(0, '')
itertools.combinations를 통해 문제를 해결하는 방법이다.
훨씬 간단하게 구할 수 있다.
import sys
from itertools import combinations
input = sys.stdin.readline
l, n = map(int, input().split())
c = sorted(list(input().split()))
for s in combinations(c, l):
m = 0
for i in range(l):
if s[i] in 'aeiou':
m += 1
if m >= 1 and l-m >= 2:
print(''.join(s))