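"""Streamlit demo for checking generated summaries against their source articles.

The app lets the user pick a sample article, displays a pre-generated summary,
highlights summary entities that do (green) or do not (red) occur in the
article, and renders dependency parses of sentences flagged by the entity
check.
"""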
# Standard library
import os
import random

# Third-party
import streamlit as st
from bs4 import BeautifulSoup
from flair.data import Sentence
from flair.models import SequenceTagger
import spacy
from spacy_streamlit.util import get_svg
from transformers import pipeline
from transformers_interpret import SequenceClassificationExplainer

# Local
from custom_renderer import render_sentence_custom
# Map model names to URLs
model_names_to_URLs = {
    'ml6team/distilbert-base-dutch-cased-toxic-comments':
        'https://huggingface.co/ml6team/distilbert-base-dutch-cased-toxic-comments',
    'ml6team/robbert-dutch-base-toxic-comments':
        'https://huggingface.co/ml6team/robbert-dutch-base-toxic-comments',
}
about_page_markdown = """# 🤬 Dutch Toxic Comment Detection Space
Made by [ML6](https://ml6.eu/).
Token attribution is performed using [transformers-interpret](https://github.com/cdpierse/transformers-interpret).
"""
regular_emojis = [
    '😐', '😑', '😶', '😏',
]
undecided_emojis = [
    '🤨', '🧐', '🥸', '🥴', '🤷',
]
potty_mouth_emojis = [
    '🤐', '👿', '😡', '🤬', '☠️', '☣️', '☢️',
]
# Page setup
st.set_page_config(
    page_title="Toxic Comment Detection Space",
    page_icon="🤬",
    layout="centered",
    initial_sidebar_state="auto",
    menu_items={
        'Get help': None,
        'Report a bug': None,
        'About': about_page_markdown,
    }
)
# Model setup
def load_pipeline(model_name):
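    """Load the toxicity classification pipeline and a matching token-attribution explainer.

    Note: nothing is cached here, so every call reloads the model; wrapping
    this in Streamlit's resource cache would avoid repeated loads.
    """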
    with st.spinner('Loading model (this might take a while)...'):
        toxicity_pipeline = pipeline(
            'text-classification',
            model=model_name,
            tokenizer=model_name)
        cls_explainer = SequenceClassificationExplainer(
            toxicity_pipeline.model,
            toxicity_pipeline.tokenizer)
    return toxicity_pipeline, cls_explainer
# Auxiliary functions
def format_explainer_html(html_string):
| """Extract tokens with attribution-based background color.""" | |
| inside_token_prefix = '##' | |
| soup = BeautifulSoup(html_string, 'html.parser') | |
| p = soup.new_tag('p', | |
| attrs={'style': 'color: black; background-color: white;'}) | |
| # Select token elements and remove model specific tokens | |
| current_word = None | |
| for token in soup.find_all('td')[-1].find_all('mark')[1:-1]: | |
| text = token.font.text.strip() | |
| if text.startswith(inside_token_prefix): | |
| text = text[len(inside_token_prefix):] | |
| else: | |
| # Create a new span for each word (sequence of sub-tokens) | |
| if current_word is not None: | |
| p.append(current_word) | |
| p.append(' ') | |
| current_word = soup.new_tag('span') | |
| token.string = text | |
| token.attrs['style'] = f"{token.attrs['style']}; padding: 0.2em 0em;" | |
| current_word.append(token) | |
| # Add last word | |
| p.append(current_word) | |
| # Add left and right-padding to each word | |
| for span in p.find_all('span'): | |
| span.find_all('mark')[0].attrs['style'] = ( | |
| f"{span.find_all('mark')[0].attrs['style']}; padding-left: 0.2em;") | |
| span.find_all('mark')[-1].attrs['style'] = ( | |
| f"{span.find_all('mark')[-1].attrs['style']}; padding-right: 0.2em;") | |
| return p | |
def list_all_article_names() -> list:
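    """Return the base names (without .txt) of the articles in ./sample-articles/."""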
    filenames = []
    # Sort for a deterministic ordering in the selectbox
    for file in sorted(os.listdir('./sample-articles/')):
        if file.endswith('.txt'):
            filenames.append(file.replace('.txt', ''))
    return filenames
def fetch_article_contents(filename: str) -> str:
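    """Read the raw text of the given sample article from disk."""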
    with open(f'./sample-articles/{filename.lower()}.txt', 'r') as f:
        data = f.read()
    return data
def fetch_summary_contents(filename: str) -> str:
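    """Read the pre-generated summary for the given article from disk."""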
    with open(f'./sample-summaries/{filename.lower()}.txt', 'r') as f:
        data = f.read()
    return data
def classify_comment(comment, selected_model):
    """Classify the given comment and augment with additional information."""
    toxicity_pipeline, cls_explainer = load_pipeline(selected_model)
    result = toxicity_pipeline(comment)[0]
    result['model_name'] = selected_model

    # Add explanation
    result['word_attribution'] = cls_explainer(comment, class_name="non-toxic")
    result['visualization_html'] = cls_explainer.visualize()._repr_html_()
    result['tokens_with_background'] = format_explainer_html(
        result['visualization_html'])

    # Choose emoji reaction
    label, score = result['label'], result['score']
    if label == 'toxic' and score > 0.1:
        emoji = random.choice(potty_mouth_emojis)
    elif label in ['non_toxic', 'non-toxic'] and score > 0.1:
        emoji = random.choice(regular_emojis)
    else:
        emoji = random.choice(undecided_emojis)
    result.update({'text': comment, 'emoji': emoji})

    # Add result to session
    st.session_state.results.append(result)
# Start session
if 'results' not in st.session_state:
    st.session_state.results = []
# Page
# st.title('🤬 Dutch Toxic Comment Detection')
# st.markdown("""This demo showcases two Dutch toxic comment detection models.""")
#
# # Introduction
# st.markdown(f"""Both models were trained using a sequence classification task on a translated [Jigsaw Toxicity dataset](https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge) which contains toxic online comments.
# The first model is a fine-tuned multilingual [DistilBERT](https://huggingface.co/distilbert-base-multilingual-cased) model whereas the second is a fine-tuned Dutch RoBERTa-based model called [RobBERT](https://huggingface.co/pdelobelle/robbert-v2-dutch-base).""")
# st.markdown(f"""For a more comprehensive overview of the models check out their model card on 🤗 Model Hub: [distilbert-base-dutch-toxic-comments]({model_names_to_URLs['ml6team/distilbert-base-dutch-cased-toxic-comments']}) and [RobBERT-dutch-base-toxic-comments]({model_names_to_URLs['ml6team/robbert-dutch-base-toxic-comments']}).
# """)
# st.markdown("""Enter a comment that you want to classify below. The model will determine the probability that it is toxic and highlights how much each token contributes to its decision:
# <font color="black">
# <span style="background-color: rgb(250, 219, 219); opacity: 1;">r</span><span style="background-color: rgb(244, 179, 179); opacity: 1;">e</span><span style="background-color: rgb(238, 135, 135); opacity: 1;">d</span>
# </font>
# tokens indicate toxicity whereas
# <font color="black">
# <span style="background-color: rgb(224, 251, 224); opacity: 1;">g</span><span style="background-color: rgb(197, 247, 197); opacity: 1;">re</span><span style="background-color: rgb(121, 236, 121); opacity: 1;">en</span>
# </font> tokens indicate the opposite.
#
# Try it yourself! 👇""",
#             unsafe_allow_html=True)

# Demo
# with st.form("dutch-toxic-comment-detection-input", clear_on_submit=True):
#     selected_model = st.selectbox('Select a model:', model_names_to_URLs.keys())
#     text = st.text_area(
#         label='Enter the comment you want to classify below (in Dutch):')
#     _, rightmost_col = st.columns([6, 1])
#     submitted = rightmost_col.form_submit_button("Classify",
#                                                  help="Classify comment")
# TODO: should probably set a minimum length of article or something
selected_article = st.selectbox('Select an article or provide your own:',
                                list_all_article_names())
st.session_state.article_text = fetch_article_contents(selected_article)
article_text = st.text_area(
    label='Full article text',
    value=st.session_state.article_text,
    height=250
)
# _, rightmost_col = st.columns([5, 1])
# get_summary = rightmost_col.button("Generate summary",
#                                    help="Generate summary for the given article text")
def display_summary(article_name: str):
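    """Fetch the pre-generated summary for the given article and render it as HTML."""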
| st.subheader("Generated summary") | |
| # st.markdown("######") | |
| summary_content = fetch_summary_contents(article_name) | |
| soup = BeautifulSoup(summary_content, features="html.parser") | |
| HTML_WRAPPER = """<div style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem; margin-bottom: 2.5rem">{}</div>""" | |
| st.session_state.summary_output = HTML_WRAPPER.format(soup) | |
| st.write(st.session_state.summary_output, unsafe_allow_html=True) | |
# TODO: this functionality can be cached (e.g. by storing the rendered HTML
# output, or just the entity lists)
def get_and_compare_entities_spacy(article_name: str):
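    """Split summary entities into those that do and do not appear in the article.

    Entities are extracted with spaCy; a summary entity counts as matched when
    it is a case-insensitive substring of any article entity.
    """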
    nlp = spacy.load('en_core_web_lg')

    article_content = fetch_article_contents(article_name)
    doc = nlp(article_content)
    entities_article = [str(entity) for entity in doc.ents]

    summary_content = fetch_summary_contents(article_name)
    doc = nlp(summary_content)
    entities_summary = [str(entity) for entity in doc.ents]

    matched_entities = []
    unmatched_entities = []
    for entity in entities_summary:
        # TODO: plain substring matching for now; an embedding-based similarity
        # would handle paraphrases and aliases better.
        if any(entity.lower() in article_entity.lower()
               for article_entity in entities_article):
            matched_entities.append(entity)
        else:
            unmatched_entities.append(entity)
    return matched_entities, unmatched_entities
def get_and_compare_entities_flair(article_name: str):
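    """Like get_and_compare_entities_spacy, but entities come from Flair NER.

    spaCy is used only for sentence segmentation; the Flair OntoNotes tagger
    labels each sentence.
    """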
    nlp = spacy.load('en_core_web_sm')
    tagger = SequenceTagger.load("flair/ner-english-ontonotes-fast")

    article_content = fetch_article_contents(article_name)
    doc = nlp(article_content)
    entities_article = []
    for sentence in doc.sents:
        sentence_entities = Sentence(str(sentence))
        tagger.predict(sentence_entities)
        for entity in sentence_entities.get_spans('ner'):
            entities_article.append(entity.text)

    summary_content = fetch_summary_contents(article_name)
    doc = nlp(summary_content)
    entities_summary = []
    for sentence in doc.sents:
        sentence_entities = Sentence(str(sentence))
        tagger.predict(sentence_entities)
        for entity in sentence_entities.get_spans('ner'):
            entities_summary.append(entity.text)

    matched_entities = []
    unmatched_entities = []
    for entity in entities_summary:
        # TODO: same substring matching as in get_and_compare_entities_spacy.
        if any(entity.lower() in article_entity.lower()
               for article_entity in entities_article):
            matched_entities.append(entity)
        else:
            unmatched_entities.append(entity)
    return matched_entities, unmatched_entities
def highlight_entities(article_name: str):
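    """Show the summary with entities highlighted: green if found in the article, red if not."""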
| st.subheader("Match entities with article") | |
| # st.markdown("####") | |
| summary_content = fetch_summary_contents(article_name) | |
| markdown_start_red = "<mark class=\"entity\" style=\"background: rgb(238, 135, 135);\">" | |
| markdown_start_green = "<mark class=\"entity\" style=\"background: rgb(121, 236, 121);\">" | |
| markdown_end = "</mark>" | |
| matched_entities, unmatched_entities = get_and_compare_entities_spacy(article_name) | |
| for entity in matched_entities: | |
| summary_content = summary_content.replace(entity, markdown_start_green + entity + markdown_end) | |
| for entity in unmatched_entities: | |
| summary_content = summary_content.replace(entity, markdown_start_red + entity + markdown_end) | |
| soup = BeautifulSoup(summary_content, features="html.parser") | |
| HTML_WRAPPER = """<div style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem; margin-bottom: 2.5rem">{}</div>""" | |
| st.write(HTML_WRAPPER.format(soup), unsafe_allow_html=True) | |
def render_dependency_parsing(text: str):
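    """Render the dependency parse of the given text as an SVG via the custom renderer."""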
    # Earlier experiments rendered the parse with displacy and
    # spacy_streamlit.visualize_parser; the custom renderer below gives more
    # control over the generated SVG.
    html = render_sentence_custom(text)
    # Double newlines seem to mess with the rendering
    html = html.replace("\n\n", "\n")
    st.write(get_svg(html), unsafe_allow_html=True)
def check_dependency(text):
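    """Collect sentences whose adjectival modifier (amod) involves a named entity.

    Entities are gathered per sentence with both spaCy and Flair; a sentence is
    kept when an amod token or its syntactic head matches one of them. Returns
    the kept sentences as a newline-joined string.
    """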
    tagger = SequenceTagger.load("flair/ner-english-ontonotes-fast")
    nlp = spacy.load('en_core_web_lg')
    doc = nlp(text)
    tok_l = doc.to_json()['tokens']
    all_deps = ""

    for sentence in doc.sents:
        # Collect this sentence's entities with both spaCy and Flair.
        all_entities = [str(entity) for entity in sentence.ents]
        sentence_entities = Sentence(str(sentence))
        tagger.predict(sentence_entities)
        for entity in sentence_entities.get_spans('ner'):
            all_entities.append(entity.text)

        # Restrict to the tokens of this sentence (sentence.end is exclusive).
        start_id = sentence.start
        end_id = sentence.end
        for t in tok_l:
            if t["id"] < start_id or t["id"] >= end_id:
                continue
            head = tok_l[t['head']]
            if t['dep'] == 'amod':
                object_here = text[t['start']:t['end']]
                object_target = text[head['start']:head['end']]
                # Keep the sentence if either end of the amod relation is a
                # recognized entity.
                if object_here in all_entities or object_target in all_entities:
                    all_deps += str(sentence) + "\n"
    return all_deps
| with st.form("article-input"): | |
| left_column, _ = st.columns([1, 1]) | |
| get_summary = left_column.form_submit_button("Generate summary", | |
| help="Generate summary for the given article text") | |
| # Listener | |
| if get_summary: | |
| if article_text: | |
| with st.spinner('Generating summary...'): | |
| # classify_comment(article_text, selected_model) | |
| display_summary(selected_article) | |
| else: | |
| st.error('**Error**: No comment to classify. Please provide a comment.') | |
# Entity part
with st.form("Entity-part"):
    left_column, _ = st.columns([1, 1])
    draw_entities = left_column.form_submit_button("Draw Entities",
                                                   help="Draw Entities")
    if draw_entities:
        with st.spinner("Drawing entities..."):
            highlight_entities(selected_article)
| with st.form("Dependency-usage"): | |
| left_column, _ = st.columns([1, 1]) | |
| parsing = left_column.form_submit_button("Dependency parsing", | |
| help="Dependency parsing") | |
| if parsing: | |
| with st.spinner("Doing dependency parsing..."): | |
| render_dependency_parsing(check_dependency(fetch_summary_contents(selected_article))) | |
# Results
# if 'results' in st.session_state and st.session_state.results:
#     first = True
#     for result in st.session_state.results[::-1]:
#         if not first:
#             st.markdown("---")
#         st.markdown(f"Text:\n> {result['text']}")
#         col_1, col_2, col_3 = st.columns([1, 2, 2])
#         col_1.metric(label='', value=f"{result['emoji']}")
#         col_2.metric(label='Label', value=f"{result['label']}")
#         col_3.metric(label='Score', value=f"{result['score']:.3f}")
#         st.markdown(f"Token Attribution:\n{result['tokens_with_background']}",
#                     unsafe_allow_html=True)
#         st.caption(f"Model: {result['model_name']}")
#         first = False