# Source code for matengine.generation.decisiontree

import numpy as np

class decision_tree:
    """
    A class to represent a decision tree for making decisions based on
    user-defined criteria.

    The tree is described by an ``instructions`` mapping passed to
    ``build_tree`` and evaluated against input data with ``decide``.
    """

    def __init__(self):
        """
        Initialize the decision tree with an empty structure.
        """
        # Root node of the built tree; None until build_tree() succeeds.
        self.tree = None

    def build_tree(self, instructions):
        """
        Build a decision tree based on the given instructions.

        Parameters:
            instructions: dict
                A dictionary mapping node ids to node descriptions. Each
                description is a dict with a ``'type'`` key:

                - ``'decision'``: requires ``'func'``, which is called as
                  ``func(data, **args)``. ``'args'`` is optional and
                  defaults to no extra arguments. ``'yes_branch'`` and
                  ``'no_branch'`` are optional ids of other nodes; a
                  missing key or None means "no branch".
                - ``'leaf'``: requires ``'action'``, the value returned
                  when the leaf is reached.

                Entries whose type is neither of these are ignored.

        Raises:
            KeyError: If no 'decision' or 'leaf' node has the id 'root'.
                Also raised if a decision node's branch names a node id
                that is not defined.

        Notes:
            - Nodes can be either decision nodes or leaf nodes.
            - The 'root' node is the starting point.
            - ``self.tree`` is only assigned after the whole description
              has been processed, so a failed build leaves any previously
              built tree in place.
        """
        nodes = {}
        # First pass: create every node so that branches may refer to nodes
        # that appear later in the mapping.
        for node_id, details in instructions.items():
            node_type = details['type']
            if node_type == 'decision':
                nodes[node_id] = self.DecisionNode(
                    func=details['func'],
                    args=details.get('args', {}),
                )
            elif node_type == 'leaf':
                nodes[node_id] = self.LeafNode(details['action'])

        # Second pass: wire up the branches of every decision node.
        for node_id, details in instructions.items():
            if details['type'] != 'decision':
                continue
            node = nodes[node_id]
            node.yes_branch = self._resolve_branch(
                nodes, node_id, details, 'yes_branch')
            node.no_branch = self._resolve_branch(
                nodes, node_id, details, 'no_branch')

        if 'root' not in nodes:
            raise KeyError(
                "instructions must define a node with id 'root' "
                "of type 'decision' or 'leaf'")
        self.tree = nodes['root']

    @staticmethod
    def _resolve_branch(nodes, node_id, details, branch_key):
        """
        Return the node named by ``details[branch_key]``, or None if unset.

        A branch id that does not name a built node raises KeyError. Without
        this check, a typo would surface only later as a silent None decision.
        """
        target = details.get(branch_key)
        if target is None:
            return None
        try:
            return nodes[target]
        except KeyError:
            raise KeyError(
                "node {!r} has {} {!r}, which is not a defined node".format(
                    node_id, branch_key, target)) from None

    def decide(self, data):
        """
        Make a decision using the decision tree.

        Parameters:
            data: dict
                The input data used for making a decision.

        Returns:
            The action of the leaf node reached, or None if a decision node
            on the path has no branch for its outcome.

        Raises:
            ValueError: If the decision tree has not been built yet.
        """
        if self.tree is None:
            raise ValueError("The decision tree has not been built yet.")
        return self.tree.decide(data)

    class DecisionNode:
        """
        A class to represent a decision node in the decision tree.
        """

        def __init__(self, func, args, yes_branch=None, no_branch=None):
            """
            Initialize a decision node.

            Parameters:
                func: function
                    The predicate used to evaluate the decision. It is
                    called as ``func(data, **args)``.
                args: dict
                    Keyword arguments passed to ``func`` after ``data``.
                yes_branch: DecisionNode or LeafNode, optional
                    The branch to follow if the predicate is truthy.
                no_branch: DecisionNode or LeafNode, optional
                    The branch to follow if the predicate is falsy.
            """
            self.func = func
            self.args = args
            self.yes_branch = yes_branch
            self.no_branch = no_branch

        def decide(self, data):
            """
            Make a decision at this node based on the input data.

            Parameters:
                data: dict
                    The input data used for making a decision.

            Returns:
                The result of the chosen branch's ``decide``, or None if the
                chosen branch is not set.
            """
            if self.func(data, **self.args):
                branch = self.yes_branch
            else:
                branch = self.no_branch
            return branch.decide(data) if branch else None

    class LeafNode:
        """
        A class to represent a leaf node in the decision tree.
        """

        def __init__(self, action):
            """
            Initialize a leaf node.

            Parameters:
                action:
                    The value returned when this leaf node is reached.
            """
            self.action = action

        def decide(self, data):
            """
            Return the action associated with this leaf node.

            Parameters:
                data: dict
                    The input data (unused by leaves).

            Returns:
                The action stored in this leaf node.
            """
            return self.action
def less_than(data, key, threshold):
    """
    Test whether ``data[key]`` is at most ``threshold``.

    Despite the name, the comparison is inclusive: a value equal to the
    threshold yields True. For ordinary numbers this makes it the exact
    complement of ``greater_than``.

    Parameters:
        data: dict
            The input data.
        key: str
            Key of the value being compared.
        threshold: numeric
            The threshold value.

    Returns:
        bool: True if ``data[key] <= threshold``, False otherwise.
    """
    value = data[key]
    return value <= threshold
def greater_than(data, key, threshold):
    """
    Test whether ``data[key]`` strictly exceeds ``threshold``.

    Parameters:
        data: dict
            The input data.
        key: str
            Key of the value being compared.
        threshold: numeric
            The threshold value.

    Returns:
        bool: True if ``data[key] > threshold``, False otherwise (including
        when the value equals the threshold).
    """
    value = data[key]
    return value > threshold
def linear(data, key1, key2, m, c):
    """
    Test whether a point lies on or above the line ``y = m*x + c``.

    Parameters:
        data: dict
            The input data.
        key1: str
            Key of the point's x-value.
        key2: str
            Key of the point's y-value.
        m: float
            Slope of the line.
        c: float
            y-intercept of the line.

    Returns:
        bool: True if the point is on or above the line, False otherwise.
    """
    # y is read before x to keep the original lookup order.
    y = data[key2]
    x = data[key1]
    # Signed vertical offset of the point from the line; >= 0 means the
    # point is on or above it.
    offset = y - m * x - c
    return offset >= 0
def ellipse(data, key1, key2, cx, cy, sx, sy, angle=0):
    """
    Test whether a 2-D point lies inside or on a (possibly rotated) ellipse.

    Parameters:
        data: dict
            The input data.
        key1: str
            Key of the point's x-coordinate.
        key2: str
            Key of the point's y-coordinate.
        cx: float
            x-coordinate of the ellipse centre.
        cy: float
            y-coordinate of the ellipse centre.
        sx: float
            Semi-axis length along the ellipse's first axis. This axis is
            the x-direction when ``angle`` is 0.
        sy: float
            Semi-axis length along the ellipse's second axis. This axis is
            the y-direction when ``angle`` is 0.
        angle: float, optional
            Rotation of the ellipse in degrees. The ``sx`` axis points at
            ``angle`` degrees from the +x axis, toward +y. Default is 0.

    Returns:
        bool: True if the point lies within or on the ellipse, False otherwise.
    """
    theta = np.radians(angle)
    dx = data[key1] - cx
    dy = data[key2] - cy
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    # Rotate the centred point by -angle so the ellipse axes line up with
    # the coordinate axes.
    u = dx * cos_t + dy * sin_t
    v = -dx * sin_t + dy * cos_t
    return (u / sx) ** 2 + (v / sy) ** 2 <= 1
def ellipsoid(data, key1, key2, key3, cx, cy, cz, sx, sy, sz):
    """
    Test whether a 3-D point lies inside or on an axis-aligned ellipsoid.

    Parameters:
        data: dict
            The input data.
        key1: str
            Key of the point's x-coordinate.
        key2: str
            Key of the point's y-coordinate.
        key3: str
            Key of the point's z-coordinate.
        cx: float
            x-coordinate of the ellipsoid centre.
        cy: float
            y-coordinate of the ellipsoid centre.
        cz: float
            z-coordinate of the ellipsoid centre.
        sx: float
            Semi-axis length in the x-direction.
        sy: float
            Semi-axis length in the y-direction.
        sz: float
            Semi-axis length in the z-direction.

    Returns:
        bool: True if the point lies within or on the ellipsoid, False otherwise.
    """
    # Squared normalised offset along each axis; the point is inside when
    # these add up to at most 1.
    tx = ((data[key1] - cx) / sx) ** 2
    ty = ((data[key2] - cy) / sy) ** 2
    tz = ((data[key3] - cz) / sz) ** 2
    return tx + ty + tz <= 1