Flair를 이용한 Text Classification

본 글은 다음 글을 참고하여 작성했습니다.

https://towardsdatascience.com/text-classification-with-state-of-the-art-nlp-library-flair-b541d7add21f

Flair 설치

Flair는 python 3.6 이상에서만 동작한다.

$ pip install flair

트레이닝 된 모델 테스트 해보기

Flair에서 제공하는 트레이닝 된 모델은 텍스트에서 긍/부정 정도를 측정 할 수 있는 모델이다. 다음 코드로 테스트 할 수 있다.

from flair.models import TextClassifier
from flair.data import Sentence

classifier = TextClassifier.load('en-sentiment')
sentence = Sentence('Flair is pretty neat!')
classifier.predict(sentence)

# print sentence with predicted labels
print('Sentence above is: ', sentence.labels)

위 코드를 수행하면 Flair에서 미리 준비된 모델을 다운로드 한다. 용량이 꽤 커서 다운로드 하는데 몇 분 정도 소요될 수 있다. 프로그램 수행 결과는 대략 다음과 같다.

$ python test.py 
2019-12-17 10:56:35,585 loading file /Users/victor/.flair/models/imdb-v0.4.pt
Sentence above is:  [POSITIVE (0.6636102199554443)]

위 결과에서 디렉토리 패스를 보면 알 수 있듯이 .flair라는 디렉토리를 생성하고 하위에 imdb-v0.4.pt 모델이 다운로드 된 것을 알 수 있다.

커스텀 트레이닝 해보기

이제 새로운 데이터셋을 이용해서 직접 트레이닝 하는 방법을 알아보자.

Flair는 페이스북의 FastText format과 같은 데이터셋 구조를 사용한다. 해당 구조는 다음과 같이 생겼다.

__label__<class_1> <text>
__label__<class_2> <text>

위와 같이 __label__<class x> 와 같이 label을 나누고 그 뒤에 데이터에 해당하는 <text>가 자리한다. 예를 들면 다음과 같다.

__label__happy 나는 너무 행복해요.

그러면 kaggle에서 제공하는 spam 데이터셋을 이용해서 스팸 문자를 분류하는 모델을 만들어보자. 데이터셋을 받아서 에디터로 열어보면 다음과 같은 구조로 되어 있다.

v1,v2,,,
ham,"Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat...",,,
ham,Ok lar... Joking wif u oni...,,,
spam,Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's,,,
ham,U dun say so early hor... U c already then say...,,,
ham,"Nah I don't think he goes to usf, he lives around here though",,,
spam,"FreeMsg Hey there darling it's been 3 week's now and no word back! I'd like some fun you up for it still? Tb ok! XxX std chgs to send, å£1.50 to rcv",,,
ham,Even my brother is not like to speak with me. They treat me like aids patent.,,,

데이터를 보면 v1, v2 column 레이블이 있고 그 뒤로는 null로 된 칼럼이 두개 있다. 우리가 필요한 데이터는 v1, v2 column들이다. 따라서 v1을 __label__spam, __label__ham으로 변경해야한다.

Preprocessing – 데이터셋 변환 작업

앞서 보였던 데이터를 Flair가 인식할 수 있는 형태로 변경하기 위해 Pandas 패키지를 이용한다. 설치는 pip install pandas 로 할 수 있다. 데이터셋 변환 코드는 다음과 같다.

import pandas as pd

# 데이터셋 읽기
data = pd.read_csv("./spam.csv", encoding='latin-1').sample(frac=1).drop_duplicates()

# column 레이블 이름을 v1, v2에서 label, text로 변경
data = data[['v1', 'v2']].rename(columns={"v1":"label", "v2":"text"})

# label column에 있는 데이터들을 __label__<class x> 형태로 변환하기 
data['label'] = '__label__' + data['label'].astype(str)

# 변환한 데이터들을 train.csv, test.csv, dev.csv 로 각각 0.8, 0.1, 0.1 비율 개수로 나누기
data.iloc[0:int(len(data)*0.8)].to_csv('train.csv', sep='\t', index = False, header = False)
data.iloc[int(len(data)*0.8):int(len(data)*0.9)].to_csv('test.csv', sep='\t', index = False, header = False)
data.iloc[int(len(data)*0.9):].to_csv('dev.csv', sep='\t', index = False, header = False);

위와 같이 수행하면 train.csv, test.csv, dev.csv 데이터 파일들이 생성되며 그 중 하나를 열어보면 결과는 다음과 같다.

__label__ham    Of course. I guess god's just got me on hold right now.
__label__ham    I'm outside islands, head towards hard rock and you'll run into me
__label__spam   8007 FREE for 1st week! No1 Nokia tone 4 ur mob every week just txt NOKIA to 8007 Get txting and tell ur mates www.getzed.co.uk POBox 36504 W4 5WQ norm 150p/tone 16+
__label__ham    Water logging in desert. Geoenvironmental implications.
__label__ham    "What part of \don't initiate\"" don't you understand"""
__label__ham    How do friends help us in problems? They give the most stupid suggestion that Lands us into another problem and helps us forgt the previous problem
__label__ham    For my family happiness..
__label__ham    Sorry, I'll call later

Text Classification Model 트레이닝 하기

데이터셋이 준비되었으니 이제 모델을 트레이닝 해보자. 트레이닝 코드는 다음과 같다.

from flair.data_fetcher import NLPTaskDataFetcher
from flair.embeddings import WordEmbeddings, FlairEmbeddings, DocumentLSTMEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer
from pathlib import Path

corpus = NLPTaskDataFetcher.load_classification_corpus(Path('./'), test_file='test.csv', dev_file='dev.csv', train_file='train.csv')

word_embeddings = [WordEmbeddings('glove'), FlairEmbeddings('news-forward-fast'), FlairEmbeddings('news-backward-fast')]

document_embeddings = DocumentLSTMEmbeddings(word_embeddings, hidden_size=512, reproject_words=True, reproject_words_dimension=256)

classifier = TextClassifier(document_embeddings, label_dictionary=corpus.make_label_dictionary(), multi_label=False)

trainer = ModelTrainer(classifier, corpus)
trainer.train('./', max_epochs=10)

위와 같이 트레이닝을 완료하면 best-model.pt 파일이 생성된다.

Prediction 해보기

새로 만든 model을 테스트 해보자. 다음 코드를 활용해서 테스트 할 수 있다.

from flair.models import TextClassifier
from flair.data import Sentence

classifier = TextClassifier.load('./best-model.pt')
sentence = Sentence('Hi. Yes mum, I will...')
classifier.predict(sentence)
print(sentence.labels)

위 코드의 실행 결과는 다음과 같다.

[ham (0.999873161315918)]

Leave a Reply