Updated TED method
This commit is contained in:
@@ -656,6 +656,223 @@
|
||||
"results, methods = evaluate_baseline_methods(test_pairs)\n",
|
||||
"print_comparison_table(results, test_pairs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "06db7e83",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Tree Edit Distance analysis"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c62457eb",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import spacy\n",
|
||||
"\n",
|
||||
"# Load spaCy model for dependency parsing\n",
|
||||
"nlp = spacy.load(\"en_core_web_lg\")\n",
|
||||
"print(\"Loaded spaCy model for Tree Edit Distance\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c5905e3b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"##### Build trees"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2aea0b77",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def build_tree_from_dependencies(text):\n",
|
||||
" \"\"\"Build hierarchical tree structure from dependency parse.\"\"\"\n",
|
||||
" doc = nlp(text)\n",
|
||||
" \n",
|
||||
" # Create node dictionary with parent-child relationships\n",
|
||||
" tree = {}\n",
|
||||
" root_token = None\n",
|
||||
" \n",
|
||||
" for token in doc:\n",
|
||||
" tree[token.i] = {\n",
|
||||
" 'text': token.text,\n",
|
||||
" 'lemma': token.lemma_,\n",
|
||||
" 'pos': token.pos_,\n",
|
||||
" 'dep': token.dep_,\n",
|
||||
" 'head_id': token.head.i,\n",
|
||||
" 'children': []\n",
|
||||
" }\n",
|
||||
" \n",
|
||||
" # Build parent-child relationships\n",
|
||||
" for token in doc:\n",
|
||||
" if token.head.i != token.i: # Not root\n",
|
||||
" tree[token.head.i]['children'].append(token.i)\n",
|
||||
" else: # Root node\n",
|
||||
" root_token = token.i\n",
|
||||
" \n",
|
||||
" return tree, root_token\n",
|
||||
"\n",
|
||||
"print(\"Tree building function loaded\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b45d5916",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"##### Helper Functions"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a5e5fa5a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_nodes_postorder(node_id, tree):\n",
|
||||
" \"\"\"Get all nodes in postorder traversal (children before parents).\"\"\"\n",
|
||||
" nodes = []\n",
|
||||
" if node_id in tree:\n",
|
||||
" for child_id in tree[node_id]['children']:\n",
|
||||
" nodes.extend(get_nodes_postorder(child_id, tree))\n",
|
||||
" nodes.append(node_id)\n",
|
||||
" return nodes\n",
|
||||
"\n",
|
||||
"def node_label(node_id, tree):\n",
|
||||
" \"\"\"Get label for a node (POS tag + dependency relation).\"\"\"\n",
|
||||
" if node_id in tree:\n",
|
||||
" return (tree[node_id]['pos'], tree[node_id]['dep'])\n",
|
||||
" return ('', '')\n",
|
||||
"\n",
|
||||
"def nodes_match(node_id1, tree1, node_id2, tree2):\n",
|
||||
" \"\"\"Check if two nodes have matching labels.\"\"\"\n",
|
||||
" label1 = node_label(node_id1, tree1)\n",
|
||||
" label2 = node_label(node_id2, tree2)\n",
|
||||
" return label1 == label2\n",
|
||||
"\n",
|
||||
"print(\"Helper functions loaded\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5176cf24",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"##### Tree edit distance function"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3ab8c40c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def tree_edit_distance_zhang_shasha(tree1, tree2, root1, root2):\n",
|
||||
" \"\"\"\n",
|
||||
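" Approximate tree edit distance: Levenshtein distance between the postorder (POS, dep) label sequences of the two trees (not the full Zhang-Shasha algorithm).\n",
|
||||
" Reflects node ordering; hierarchical structure is captured only through the postorder linearization.\n",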
|
||||
" Returns: Normalized similarity score (0-1)\n",
|
||||
" \"\"\"\n",
|
||||
" \n",
|
||||
" # get postorder traversals\n",
|
||||
" nodes1 = get_nodes_postorder(root1, tree1)\n",
|
||||
" nodes2 = get_nodes_postorder(root2, tree2)\n",
|
||||
" \n",
|
||||
" m, n = len(nodes1), len(nodes2)\n",
|
||||
" \n",
|
||||
" # DP table\n",
|
||||
" dp = [[0] * (n + 1) for _ in range(m + 1)]\n",
|
||||
" \n",
|
||||
" # Base cases\n",
|
||||
" for i in range(m + 1):\n",
|
||||
" dp[i][0] = i\n",
|
||||
" for j in range(n + 1):\n",
|
||||
" dp[0][j] = j\n",
|
||||
" \n",
|
||||
" # Fill DP table\n",
|
||||
" for i in range(1, m + 1):\n",
|
||||
" for j in range(1, n + 1):\n",
|
||||
" node1 = nodes1[i - 1]\n",
|
||||
" node2 = nodes2[j - 1]\n",
|
||||
" \n",
|
||||
" if nodes_match(node1, tree1, node2, tree2):\n",
|
||||
" cost = dp[i - 1][j - 1]\n",
|
||||
" else:\n",
|
||||
" delete_cost = dp[i - 1][j] + 1\n",
|
||||
" insert_cost = dp[i][j - 1] + 1\n",
|
||||
" replace_cost = dp[i - 1][j - 1] + 1\n",
|
||||
" cost = min(delete_cost, insert_cost, replace_cost)\n",
|
||||
" \n",
|
||||
" dp[i][j] = cost\n",
|
||||
" \n",
|
||||
" # convert to similarity\n",
|
||||
" max_size = max(m, n)\n",
|
||||
" if max_size == 0:\n",
|
||||
" return 1.0\n",
|
||||
" \n",
|
||||
" edit_distance = dp[m][n]\n",
|
||||
" similarity = 1.0 - (edit_distance / max_size)\n",
|
||||
" \n",
|
||||
" return max(0.0, similarity)\n",
|
||||
"\n",
|
||||
"def tree_edit_distance_similarity(sent1, sent2):\n",
|
||||
" \"\"\"Wrapper function for Tree Edit Distance similarity.\"\"\"\n",
|
||||
" if not sent1.strip() or not sent2.strip():\n",
|
||||
" return 0.0\n",
|
||||
" \n",
|
||||
" try:\n",
|
||||
" tree1, root1 = build_tree_from_dependencies(sent1)\n",
|
||||
" tree2, root2 = build_tree_from_dependencies(sent2)\n",
|
||||
" return tree_edit_distance_zhang_shasha(tree1, tree2, root1, root2)\n",
|
||||
" except:\n",
|
||||
" return 0.0\n",
|
||||
"\n",
|
||||
"print(\"Tree Edit Distance algorithm loaded\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3bfc1dcd",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"##### Running TED standalone"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "80f72ef3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(\"Tree Edit Distance - Standalone Tests\")\n",
|
||||
"print(\"=\" * 100)\n",
|
||||
"\n",
|
||||
"test_cases = [\n",
|
||||
" (\"The cat sat on the mat.\", \"The cat sat on the mat.\"), # Exact copy\n",
|
||||
" (\"The cat sat on the mat.\", \"On the mat, the cat was sitting.\"), # Structural change\n",
|
||||
" (\"The cat sat on the mat.\", \"The feline rested on the rug.\"), # Synonym replacement\n",
|
||||
" (\"The cat sat on the mat.\", \"The dog ran in the park.\"), # Different content\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"for sent1, sent2 in test_cases:\n",
|
||||
" sim = tree_edit_distance_similarity(sent1, sent2)\n",
|
||||
" print(f\"\\n{sent1:<45} vs\")\n",
|
||||
" print(f\"{sent2:<45}\")\n",
|
||||
" print(f\"TED Similarity: {sim:.3f}\")\n",
|
||||
" print(\"-\" * 100)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
@@ -674,7 +891,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.13.12"
|
||||
"version": "3.14.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
Reference in New Issue
Block a user